From 188ede28e046c911cb8e604fd1adc2b5cc1f264b Mon Sep 17 00:00:00 2001 From: Sunho Kim Date: Thu, 25 Apr 2024 17:41:21 -0700 Subject: [PATCH] [ORC] Implement basic reoptimization. --- compiler-rt/lib/orc/common.h | 6 +- compiler-rt/lib/orc/elfnix_platform.cpp | 1 + .../Orc/JITLinkRedirectableSymbolManager.h | 31 +- .../ExecutionEngine/Orc/ReOptimizeLayer.h | 181 ++++++++++++ .../ExecutionEngine/Orc/RedirectionManager.h | 2 + llvm/lib/ExecutionEngine/Orc/CMakeLists.txt | 1 + .../Orc/JITLinkRedirectableSymbolManager.cpp | 19 +- .../ExecutionEngine/Orc/ReOptimizeLayer.cpp | 279 ++++++++++++++++++ .../ExecutionEngine/Orc/CMakeLists.txt | 1 + .../Orc/JITLinkRedirectionManagerTest.cpp | 7 +- .../Orc/ReOptimizeLayerTest.cpp | 152 ++++++++++ 11 files changed, 650 insertions(+), 30 deletions(-) create mode 100644 llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h create mode 100644 llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp create mode 100644 llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp diff --git a/compiler-rt/lib/orc/common.h b/compiler-rt/lib/orc/common.h index 73c5c4a2bd8d47..f34229a615341d 100644 --- a/compiler-rt/lib/orc/common.h +++ b/compiler-rt/lib/orc/common.h @@ -19,9 +19,9 @@ /// This macro should be used to define tags that will be associated with /// handlers in the JIT process, and call can be used to define tags f -#define ORC_RT_JIT_DISPATCH_TAG(X) \ -extern "C" char X; \ -char X = 0; +#define ORC_RT_JIT_DISPATCH_TAG(X) \ + ORC_RT_INTERFACE char X; \ + char X = 0; /// Opaque struct for external symbols. struct __orc_rt_Opaque {}; diff --git a/compiler-rt/lib/orc/elfnix_platform.cpp b/compiler-rt/lib/orc/elfnix_platform.cpp index 57673f088f77cb..24cc6e1ef11778 100644 --- a/compiler-rt/lib/orc/elfnix_platform.cpp +++ b/compiler-rt/lib/orc/elfnix_platform.cpp @@ -30,6 +30,7 @@ using namespace orc_rt; using namespace orc_rt::elfnix; // Declare function tags for functions in the JIT process. 
+ORC_RT_JIT_DISPATCH_TAG(__orc_rt_reoptimize_tag) ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_push_initializers_tag) ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_symbol_lookup_tag) diff --git a/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h index 5de0da1f52d0db..52f284c89bdade 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h @@ -25,11 +25,10 @@ class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager, public: /// Create redirection manager that uses JITLink based implementaion. static Expected> - Create(ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer, - JITDylib &JD) { + Create(ObjectLinkingLayer &ObjLinkingLayer, JITDylib &JD) { Error Err = Error::success(); auto RM = std::unique_ptr( - new JITLinkRedirectableSymbolManager(ES, ObjLinkingLayer, JD, Err)); + new JITLinkRedirectableSymbolManager(ObjLinkingLayer, JD, Err)); if (Err) return Err; return std::move(RM); @@ -53,30 +52,33 @@ class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager, constexpr static StringRef JumpStubTableName = "$IND_JUMP_"; constexpr static StringRef StubPtrTableName = "$__IND_JUMP_PTRS"; - JITLinkRedirectableSymbolManager(ExecutionSession &ES, - ObjectLinkingLayer &ObjLinkingLayer, + JITLinkRedirectableSymbolManager(ObjectLinkingLayer &ObjLinkingLayer, JITDylib &JD, Error &Err) - : ES(ES), ObjLinkingLayer(ObjLinkingLayer), JD(JD), - AnonymousPtrCreator( - jitlink::getAnonymousPointerCreator(ES.getTargetTriple())), - PtrJumpStubCreator( - jitlink::getPointerJumpStubCreator(ES.getTargetTriple())) { + : ObjLinkingLayer(ObjLinkingLayer), JD(JD), + AnonymousPtrCreator(jitlink::getAnonymousPointerCreator( + ObjLinkingLayer.getExecutionSession().getTargetTriple())), + PtrJumpStubCreator(jitlink::getPointerJumpStubCreator( + 
ObjLinkingLayer.getExecutionSession().getTargetTriple())) { if (!AnonymousPtrCreator || !PtrJumpStubCreator) Err = make_error("Architecture not supported", inconvertibleErrorCode()); if (Err) return; - ES.registerResourceManager(*this); + ObjLinkingLayer.getExecutionSession().registerResourceManager(*this); } - ~JITLinkRedirectableSymbolManager() { ES.deregisterResourceManager(*this); } + ~JITLinkRedirectableSymbolManager() { + ObjLinkingLayer.getExecutionSession().deregisterResourceManager(*this); + } StringRef JumpStubSymbolName(unsigned I) { - return *ES.intern((JumpStubPrefix + Twine(I)).str()); + return *ObjLinkingLayer.getExecutionSession().intern( + (JumpStubPrefix + Twine(I)).str()); } StringRef StubPtrSymbolName(unsigned I) { - return *ES.intern((StubPtrPrefix + Twine(I)).str()); + return *ObjLinkingLayer.getExecutionSession().intern( + (StubPtrPrefix + Twine(I)).str()); } unsigned GetNumAvailableStubs() const { return AvailableStubs.size(); } @@ -84,7 +86,6 @@ class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager, Error redirectInner(JITDylib &TargetJD, const SymbolAddrMap &NewDests); Error grow(unsigned Need); - ExecutionSession &ES; ObjectLinkingLayer &ObjLinkingLayer; JITDylib &JD; jitlink::AnonymousPointerCreator AnonymousPtrCreator; diff --git a/llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h b/llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h new file mode 100644 index 00000000000000..4adc3efad55730 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h @@ -0,0 +1,181 @@ +//===- ReOptimizeLayer.h - Re-optimization layer interface ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Re-optimization layer interface. 
+// +//===----------------------------------------------------------------------===// +#ifndef LLVM_EXECUTIONENGINE_ORC_REOPTIMIZELAYER_H +#define LLVM_EXECUTIONENGINE_ORC_REOPTIMIZELAYER_H + +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/Layer.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" +#include "llvm/ExecutionEngine/Orc/RedirectionManager.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +namespace llvm { +namespace orc { + +class ReOptimizeLayer : public IRLayer, public ResourceManager { +public: + using ReOptMaterializationUnitID = uint64_t; + + /// AddProfilerFunc will be called when ReOptimizeLayer emits the first + /// version of a materialization unit in order to inject profiling code and + /// reoptimization request code. + using AddProfilerFunc = unique_function; + + /// ReOptimizeFunc will be called when ReOptimizeLayer reoptimization of a + /// materialization unit was requested in order to reoptimize the IR module + /// based on profile data. OldRT is the ResourceTracker that tracks the old + /// function definitions. The OldRT must be kept alive until it can be + /// guaranteed that every invocation of the old function definitions has been + /// terminated. + using ReOptimizeFunc = unique_function; + + ReOptimizeLayer(ExecutionSession &ES, DataLayout &DL, IRLayer &BaseLayer, + RedirectableSymbolManager &RM) + : IRLayer(ES, BaseLayer.getManglingOptions()), ES(ES), Mangle(ES, DL), + BaseLayer(BaseLayer), RSManager(RM), ReOptFunc(identity), + ProfilerFunc(reoptimizeIfCallFrequent) {} + + void setReoptimizeFunc(ReOptimizeFunc ReOptFunc) { + this->ReOptFunc = std::move(ReOptFunc); + } + + void setAddProfilerFunc(AddProfilerFunc ProfilerFunc) { + this->ProfilerFunc = std::move(ProfilerFunc); + } + + /// Registers reoptimize runtime dispatch handlers to given PlatformJD. 
The + /// reoptimization request will not be handled if dispatch handler is not + /// registered by using this function. + Error reigsterRuntimeFunctions(JITDylib &PlatformJD); + + /// Emits the given module. This should not be called by clients: it will be + /// called by the JIT when a definition added via the add method is requested. + void emit(std::unique_ptr R, + ThreadSafeModule TSM) override; + + static const uint64_t CallCountThreshold = 10; + + /// Basic AddProfilerFunc that reoptimizes the function when the call count + /// exceeds CallCountThreshold. + static Error reoptimizeIfCallFrequent(ReOptimizeLayer &Parent, + ReOptMaterializationUnitID MUID, + unsigned CurVersion, + ThreadSafeModule &TSM); + + static Error identity(ReOptimizeLayer &Parent, + ReOptMaterializationUnitID MUID, unsigned CurVersion, + ResourceTrackerSP OldRT, ThreadSafeModule &TSM) { + return Error::success(); + } + + // Create IR reoptimize request function call. + static void createReoptimizeCall(Module &M, Instruction &IP, + GlobalVariable *ArgBuffer); + + Error handleRemoveResources(JITDylib &JD, ResourceKey K) override; + void handleTransferResources(JITDylib &JD, ResourceKey DstK, + ResourceKey SrcK) override; + +private: + class ReOptMaterializationUnitState { + public: + ReOptMaterializationUnitState() = default; + ReOptMaterializationUnitState(ReOptMaterializationUnitID ID, + ThreadSafeModule TSM) + : ID(ID), TSM(std::move(TSM)) {} + ReOptMaterializationUnitState(ReOptMaterializationUnitState &&Other) + : ID(Other.ID), TSM(std::move(Other.TSM)), RT(std::move(Other.RT)), + Reoptimizing(std::move(Other.Reoptimizing)), + CurVersion(Other.CurVersion) {} + + ReOptMaterializationUnitID getID() { return ID; } + + const ThreadSafeModule &getThreadSafeModule() { return TSM; } + + ResourceTrackerSP getResourceTracker() { + std::unique_lock Lock(Mutex); + return RT; + } + + void setResourceTracker(ResourceTrackerSP RT) { + std::unique_lock Lock(Mutex); + this->RT = RT; + } + + uint32_t 
getCurVersion() { + std::unique_lock Lock(Mutex); + return CurVersion; + } + + bool tryStartReoptimize(); + void reoptimizeSucceeded(); + void reoptimizeFailed(); + + private: + std::mutex Mutex; + ReOptMaterializationUnitID ID; + ThreadSafeModule TSM; + ResourceTrackerSP RT; + bool Reoptimizing = false; + uint32_t CurVersion = 0; + }; + + using SPSReoptimizeArgList = + shared::SPSArgList; + using SendErrorFn = unique_function; + + Expected emitMUImplSymbols(ReOptMaterializationUnitState &MUState, + uint32_t Version, JITDylib &JD, + ThreadSafeModule TSM); + + void rt_reoptimize(SendErrorFn SendResult, ReOptMaterializationUnitID MUID, + uint32_t CurVersion); + + static Expected + createReoptimizeArgBuffer(Module &M, ReOptMaterializationUnitID MUID, + uint32_t CurVersion); + + ReOptMaterializationUnitState & + createMaterializationUnitState(const ThreadSafeModule &TSM); + + void + registerMaterializationUnitResource(ResourceKey Key, + ReOptMaterializationUnitState &State); + + ReOptMaterializationUnitState & + getMaterializationUnitState(ReOptMaterializationUnitID MUID); + + ExecutionSession &ES; + MangleAndInterner Mangle; + IRLayer &BaseLayer; + RedirectableSymbolManager &RSManager; + + ReOptimizeFunc ReOptFunc; + AddProfilerFunc ProfilerFunc; + + std::mutex Mutex; + std::map MUStates; + DenseMap> MUResources; + ReOptMaterializationUnitID NextID = 1; +}; + +} // namespace orc +} // namespace llvm + +#endif diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h index 87a81b0e529ccd..4004c42d914684 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h @@ -63,6 +63,8 @@ class RedirectableSymbolManager : public RedirectionManager { const SymbolMap &InitialDests) = 0; }; +/// RedirectableMaterializationUnit materializes redirectable symbol +/// by invoking RedirectableSymbolManager::emitRedirectableSymbols class 
RedirectableMaterializationUnit : public MaterializationUnit { public: RedirectableMaterializationUnit(RedirectableSymbolManager &RM, diff --git a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt index c07e6293ad1464..008875118fdeff 100644 --- a/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/llvm/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -53,6 +53,7 @@ add_llvm_component_library(LLVMOrcJIT ThreadSafeModule.cpp RedirectionManager.cpp JITLinkRedirectableSymbolManager.cpp + ReOptimizeLayer.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp index 6d76d50271b9f2..4ef217e6c562db 100644 --- a/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp +++ b/llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp @@ -17,6 +17,7 @@ using namespace llvm::orc; void JITLinkRedirectableSymbolManager::emitRedirectableSymbols( std::unique_ptr R, const SymbolAddrMap &InitialDests) { + auto &ES = ObjLinkingLayer.getExecutionSession(); std::unique_lock Lock(Mutex); if (GetNumAvailableStubs() < InitialDests.size()) if (auto Err = grow(InitialDests.size() - GetNumAvailableStubs())) { @@ -37,7 +38,6 @@ void JITLinkRedirectableSymbolManager::emitRedirectableSymbols( R->failMaterialization(); return; } - dbgs() << *K << "\n"; SymbolToStubs[&TargetJD][K] = StubID; NewSymbolDefs[K] = JumpStubs[StubID]; NewSymbolDefs[K].setFlags(V.getFlags()); @@ -45,13 +45,14 @@ void JITLinkRedirectableSymbolManager::emitRedirectableSymbols( AvailableStubs.pop_back(); } - if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) { + // FIXME: when this fails we can return stubs to the pool + if (auto Err = redirectInner(TargetJD, InitialDests)) { ES.reportError(std::move(Err)); R->failMaterialization(); return; } - if (auto Err = redirectInner(TargetJD, InitialDests)) { + if 
(auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) { ES.reportError(std::move(Err)); R->failMaterialization(); return; @@ -85,10 +86,10 @@ Error JITLinkRedirectableSymbolManager::redirectInner( StubHandle StubID = SymbolToStubs[&TargetJD].at(K); PtrWrites.push_back({StubPointers[StubID].getAddress(), V.getAddress()}); } - if (auto Err = ES.getExecutorProcessControl().getMemoryAccess().writePointers( - PtrWrites)) - return Err; - return Error::success(); + return ObjLinkingLayer.getExecutionSession() + .getExecutorProcessControl() + .getMemoryAccess() + .writePointers(PtrWrites); } Error JITLinkRedirectableSymbolManager::grow(unsigned Need) { @@ -103,16 +104,18 @@ Error JITLinkRedirectableSymbolManager::grow(unsigned Need) { SymbolLookupSet LookupSymbols; DenseMap NewDefsMap; + auto &ES = ObjLinkingLayer.getExecutionSession(); Triple TT = ES.getTargetTriple(); auto G = std::make_unique( "", TT, TT.isArch64Bit() ? 8 : 4, - TT.isLittleEndian() ? support::little : support::big, + TT.isLittleEndian() ? 
endianness::little : endianness::big, jitlink::getGenericEdgeKindName); auto &PointerSection = G->createSection(StubPtrTableName, MemProt::Write | MemProt::Read); auto &StubsSection = G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read); + // FIXME: We can batch the stubs into one block and use address to access them for (size_t I = OldSize; I < NewSize; I++) { auto Pointer = AnonymousPtrCreator(*G, PointerSection, nullptr, 0); if (auto Err = Pointer.takeError()) diff --git a/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp b/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp new file mode 100644 index 00000000000000..e2669fd1fc86b3 --- /dev/null +++ b/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp @@ -0,0 +1,279 @@ +#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" + +using namespace llvm; +using namespace orc; + +bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() { + std::unique_lock Lock(Mutex); + if (Reoptimizing) + return false; + + Reoptimizing = true; + return true; +} + +void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() { + std::unique_lock Lock(Mutex); + assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); + Reoptimizing = false; + CurVersion++; +} + +void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() { + std::unique_lock Lock(Mutex); + assert(Reoptimizing && "Tried to mark unstarted reoptimization as done"); + Reoptimizing = false; +} + +Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) { + ExecutionSession::JITDispatchHandlerAssociationMap WFs; + using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t); + WFs[Mangle("__orc_rt_reoptimize_tag")] = + ES.wrapAsyncWithSPS(this, + &ReOptimizeLayer::rt_reoptimize); + return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs)); +} + +void ReOptimizeLayer::emit(std::unique_ptr R, + ThreadSafeModule TSM) { + auto &JD = 
R->getTargetJITDylib(); + + bool HasNonCallable = false; + for (auto &KV : R->getSymbols()) { + auto &Flags = KV.second; + if (!Flags.isCallable()) + HasNonCallable = true; + } + + if (HasNonCallable) { + BaseLayer.emit(std::move(R), std::move(TSM)); + return; + } + + auto &MUState = createMaterializationUnitState(TSM); + + if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) { + registerMaterializationUnitResource(Key, MUState); + })) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } + + if (auto Err = + ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) { + ES.reportError(std::move(Err)); + R->failMaterialization(); + return; + } + + auto InitialDests = + emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM)); + if (!InitialDests) { + ES.reportError(InitialDests.takeError()); + R->failMaterialization(); + return; + } + + RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests)); +} + +Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent, + ReOptMaterializationUnitID MUID, + unsigned CurVersion, + ThreadSafeModule &TSM) { + return TSM.withModuleDo([&](Module &M) -> Error { + Type *I64Ty = Type::getInt64Ty(M.getContext()); + GlobalVariable *Counter = new GlobalVariable( + M, I64Ty, false, GlobalValue::InternalLinkage, + Constant::getNullValue(I64Ty), "__orc_reopt_counter"); + auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion); + if (auto Err = ArgBufferConst.takeError()) + return Err; + GlobalVariable *ArgBuffer = + new GlobalVariable(M, (*ArgBufferConst)->getType(), true, + GlobalValue::InternalLinkage, (*ArgBufferConst)); + for (auto &F : M) { + if (F.isDeclaration()) + continue; + auto &BB = F.getEntryBlock(); + auto *IP = &*BB.getFirstInsertionPt(); + IRBuilder<> IRB(IP); + Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true); + Value *Cnt = IRB.CreateLoad(I64Ty, Counter); + // Use EQ to prevent further reoptimize calls. 
+ Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold); + Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1)); + (void)IRB.CreateStore(Added, Counter); + Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false); + createReoptimizeCall(M, *SplitTerminator, ArgBuffer); + } + return Error::success(); + }); +} + +Expected +ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState, + uint32_t Version, JITDylib &JD, + ThreadSafeModule TSM) { + DenseMap RenamedMap; + cantFail(TSM.withModuleDo([&](Module &M) -> Error { + MangleAndInterner Mangle(ES, M.getDataLayout()); + for (auto &F : M) + if (!F.isDeclaration()) { + std::string NewName = + (F.getName() + ".__def__." + Twine(Version)).str(); + RenamedMap[Mangle(F.getName())] = Mangle(NewName); + F.setName(NewName); + } + return Error::success(); + })); + + auto RT = JD.createResourceTracker(); + if (auto Err = + JD.define(std::make_unique( + BaseLayer, *getManglingOptions(), std::move(TSM)), + RT)) + return Err; + MUState.setResourceTracker(RT); + + SymbolLookupSet LookupSymbols; + for (auto [K, V] : RenamedMap) + LookupSymbols.add(V); + + auto ImplSymbols = + ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols, + LookupKind::Static, SymbolState::Resolved); + if (auto Err = ImplSymbols.takeError()) + return Err; + + SymbolMap Result; + for (auto [K, V] : RenamedMap) + Result[K] = (*ImplSymbols)[V]; + + return Result; +} + +void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult, + ReOptMaterializationUnitID MUID, + uint32_t CurVersion) { + auto &MUState = getMaterializationUnitState(MUID); + if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) { + SendResult(Error::success()); + return; + } + + ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule()); + auto OldRT = MUState.getResourceTracker(); + auto &JD = OldRT->getJITDylib(); + + if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) { + 
ES.reportError(std::move(Err)); + MUState.reoptimizeFailed(); + SendResult(Error::success()); + return; + } + + auto SymbolDests = + emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM)); + if (!SymbolDests) { + ES.reportError(SymbolDests.takeError()); + MUState.reoptimizeFailed(); + SendResult(Error::success()); + return; + } + + if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) { + ES.reportError(std::move(Err)); + MUState.reoptimizeFailed(); + SendResult(Error::success()); + return; + } + + MUState.reoptimizeSucceeded(); + SendResult(Error::success()); +} + +Expected ReOptimizeLayer::createReoptimizeArgBuffer( + Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) { + size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion); + std::vector ArgBuffer(ArgBufferSize); + shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size()); + if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion)) + return make_error("Could not serealize args list", + inconvertibleErrorCode()); + return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer)); +} + +void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP, + GlobalVariable *ArgBuffer) { + GlobalVariable *DispatchCtx = + M.getGlobalVariable("__orc_rt_jit_dispatch_ctx"); + if (!DispatchCtx) + DispatchCtx = new GlobalVariable(M, PointerType::get(M.getContext(), 0), + false, GlobalValue::ExternalLinkage, + nullptr, "__orc_rt_jit_dispatch_ctx"); + GlobalVariable *ReoptimizeTag = + M.getGlobalVariable("__orc_rt_reoptimize_tag"); + if (!ReoptimizeTag) + ReoptimizeTag = new GlobalVariable(M, PointerType::get(M.getContext(), 0), + false, GlobalValue::ExternalLinkage, + nullptr, "__orc_rt_reoptimize_tag"); + Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch"); + if (!DispatchFunc) { + std::vector Args = {PointerType::get(M.getContext(), 0), + PointerType::get(M.getContext(), 0), + PointerType::get(M.getContext(), 0), + IntegerType::get(M.getContext(), 64)}; 
+ FunctionType *FuncTy = + FunctionType::get(Type::getVoidTy(M.getContext()), Args, false); + DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage, + "__orc_rt_jit_dispatch", &M); + } + size_t ArgBufferSizeConst = + SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{}); + Constant *ArgBufferSize = ConstantInt::get( + IntegerType::get(M.getContext(), 64), ArgBufferSizeConst, false); + IRBuilder<> IRB(&IP); + (void)IRB.CreateCall(DispatchFunc, + {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize}); +} + +ReOptimizeLayer::ReOptMaterializationUnitState & +ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) { + std::unique_lock Lock(Mutex); + ReOptMaterializationUnitID MUID = NextID; + MUStates.emplace(MUID, + ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM))); + ++NextID; + return MUStates.at(MUID); +} + +ReOptimizeLayer::ReOptMaterializationUnitState & +ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) { + std::unique_lock Lock(Mutex); + return MUStates.at(MUID); +} + +void ReOptimizeLayer::registerMaterializationUnitResource( + ResourceKey Key, ReOptMaterializationUnitState &State) { + std::unique_lock Lock(Mutex); + MUResources[Key].insert(State.getID()); +} + +Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) { + std::unique_lock Lock(Mutex); + for (auto MUID : MUResources[K]) + MUStates.erase(MUID); + + MUResources.erase(K); + return Error::success(); +} + +void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK, + ResourceKey SrcK) { + std::unique_lock Lock(Mutex); + MUResources[DstK].insert(MUResources[SrcK].begin(), MUResources[SrcK].end()); + MUResources.erase(SrcK); +} diff --git a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt index 98c86d80730249..a2bbb10039c9a0 100644 --- a/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ 
b/llvm/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -44,6 +44,7 @@ add_llvm_unittest(OrcJITTests ThreadSafeModuleTest.cpp WrapperFunctionUtilsTest.cpp JITLinkRedirectionManagerTest.cpp + ReOptimizeLayerTest.cpp EXPORT_SYMBOLS ) diff --git a/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp index 0f87c1b7433238..170637d78d292c 100644 --- a/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/JITLinkRedirectionManagerTest.cpp @@ -38,7 +38,7 @@ class JITLinkRedirectionManagerTest : public testing::Test { nullptr, nullptr, JTMB->getTargetTriple().getTriple())); JD = &ES->createBareJITDylib("main"); ObjLinkingLayer = std::make_unique( - *ES, std::make_unique(4096)); + *ES, std::make_unique(16384)); DL = std::make_unique( cantFail(JTMB->getDefaultDataLayoutForTarget())); } @@ -49,8 +49,7 @@ class JITLinkRedirectionManagerTest : public testing::Test { }; TEST_F(JITLinkRedirectionManagerTest, BasicRedirectionOperation) { - auto RM = - JITLinkRedirectableSymbolManager::Create(*ES, *ObjLinkingLayer, *JD); + auto RM = JITLinkRedirectableSymbolManager::Create(*ObjLinkingLayer, *JD); // Bail out if we can not create if (!RM) { consumeError(RM.takeError()); @@ -65,7 +64,7 @@ TEST_F(JITLinkRedirectionManagerTest, BasicRedirectionOperation) { // No dependencies registered, can't fail. 
cantFail( R->notifyResolved({{Target, {Addr, JITSymbolFlags::Exported}}})); - cantFail(R->notifyEmitted()); + cantFail(R->notifyEmitted({})); }))); return cantFail(ES->lookup({JD}, TargetName)); }; diff --git a/llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp new file mode 100644 index 00000000000000..9f04784332f29a --- /dev/null +++ b/llvm/unittests/ExecutionEngine/Orc/ReOptimizeLayerTest.cpp @@ -0,0 +1,152 @@ +#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h" +#include "OrcTestCommon.h" +#include "llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/IRPartitionLayer.h" +#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/ObjectTransformLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/TargetParser/Host.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::orc; +using namespace llvm::jitlink; + +class ReOptimizeLayerTest : public testing::Test { +public: + ~ReOptimizeLayerTest() { + if (ES) + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + +protected: + void SetUp() override { + auto JTMB = JITTargetMachineBuilder::detectHost(); + // Bail out if we can not detect the host. 
+ if (!JTMB) { + consumeError(JTMB.takeError()); + GTEST_SKIP(); + } + + auto EPC = SelfExecutorProcessControl::Create(); + if (!EPC) { + consumeError(EPC.takeError()); + GTEST_SKIP(); + } + ES = std::make_unique(std::move(*EPC)); + JD = &ES->createBareJITDylib("main"); + ObjLinkingLayer = std::make_unique( + *ES, std::make_unique(16384)); + DL = std::make_unique( + cantFail(JTMB->getDefaultDataLayoutForTarget())); + + auto TM = JTMB->createTargetMachine(); + if (!TM) { + consumeError(TM.takeError()); + GTEST_SKIP(); + } + auto CompileFunction = + std::make_unique(std::move(*TM)); + CompileLayer = std::make_unique(*ES, *ObjLinkingLayer, + std::move(CompileFunction)); + } + + Error addIRModule(ResourceTrackerSP RT, ThreadSafeModule TSM) { + assert(TSM && "Can not add null module"); + + TSM.withModuleDo([&](Module &M) { M.setDataLayout(*DL); }); + + return ROLayer->add(std::move(RT), std::move(TSM)); + } + + JITDylib *JD{nullptr}; + std::unique_ptr ES; + std::unique_ptr ObjLinkingLayer; + std::unique_ptr CompileLayer; + std::unique_ptr ROLayer; + std::unique_ptr DL; +}; + +static Function *createRetFunction(Module *M, StringRef Name, + uint32_t ReturnCode) { + Function *Result = Function::Create( + FunctionType::get(Type::getInt32Ty(M->getContext()), {}, false), + GlobalValue::ExternalLinkage, Name, M); + + BasicBlock *BB = BasicBlock::Create(M->getContext(), Name, Result); + IRBuilder<> Builder(M->getContext()); + Builder.SetInsertPoint(BB); + + Value *RetValue = ConstantInt::get(M->getContext(), APInt(32, ReturnCode)); + Builder.CreateRet(RetValue); + return Result; +} + +TEST_F(ReOptimizeLayerTest, BasicReOptimization) { + MangleAndInterner Mangle(*ES, *DL); + + auto &EPC = ES->getExecutorProcessControl(); + EXPECT_THAT_ERROR(JD->define(absoluteSymbols( + {{Mangle("__orc_rt_jit_dispatch"), + {EPC.getJITDispatchInfo().JITDispatchFunction, + JITSymbolFlags::Exported}}, + {Mangle("__orc_rt_jit_dispatch_ctx"), + {EPC.getJITDispatchInfo().JITDispatchContext, + 
JITSymbolFlags::Exported}}, + {Mangle("__orc_rt_reoptimize_tag"), + {ExecutorAddr(), JITSymbolFlags::Exported}}})), + Succeeded()); + + auto RM = JITLinkRedirectableSymbolManager::Create(*ObjLinkingLayer, *JD); + EXPECT_THAT_ERROR(RM.takeError(), Succeeded()); + + ROLayer = std::make_unique(*ES, *DL, *CompileLayer, **RM); + ROLayer->setReoptimizeFunc( + [&](ReOptimizeLayer &Parent, + ReOptimizeLayer::ReOptMaterializationUnitID MUID, unsigned CurVerison, + ResourceTrackerSP OldRT, ThreadSafeModule &TSM) { + TSM.withModuleDo([&](Module &M) { + for (auto &F : M) { + if (F.isDeclaration()) + continue; + for (auto &B : F) { + for (auto &I : B) { + if (ReturnInst *Ret = dyn_cast(&I)) { + Value *RetValue = + ConstantInt::get(M.getContext(), APInt(32, 53)); + Ret->setOperand(0, RetValue); + } + } + } + } + }); + return Error::success(); + }); + EXPECT_THAT_ERROR(ROLayer->reigsterRuntimeFunctions(*JD), Succeeded()); + + ThreadSafeContext Ctx(std::make_unique()); + auto M = std::make_unique("
", *Ctx.getContext()); + M->setTargetTriple(sys::getProcessTriple()); + + (void)createRetFunction(M.get(), "main", 42); + + EXPECT_THAT_ERROR(addIRModule(JD->getDefaultResourceTracker(), + ThreadSafeModule(std::move(M), std::move(Ctx))), + Succeeded()); + + auto Result = cantFail(ES->lookup({JD}, Mangle("main"))); + auto FuncPtr = Result.getAddress().toPtr(); + for (size_t I = 0; I <= ReOptimizeLayer::CallCountThreshold; I++) + EXPECT_EQ(FuncPtr(), 42); + EXPECT_EQ(FuncPtr(), 53); +}