Skip to content

Commit

Permalink
[ORC] Implement basic reoptimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
sunho committed Oct 11, 2024
1 parent 04af63b commit 188ede2
Show file tree
Hide file tree
Showing 11 changed files with 650 additions and 30 deletions.
6 changes: 3 additions & 3 deletions compiler-rt/lib/orc/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

/// This macro should be used to define tags that will be associated with
/// handlers in the JIT process, and call can be used to define tags f
#define ORC_RT_JIT_DISPATCH_TAG(X) \
extern "C" char X; \
char X = 0;
#define ORC_RT_JIT_DISPATCH_TAG(X) \
ORC_RT_INTERFACE char X; \
char X = 0;

/// Opaque struct for external symbols.
struct __orc_rt_Opaque {};
Expand Down
1 change: 1 addition & 0 deletions compiler-rt/lib/orc/elfnix_platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using namespace orc_rt;
using namespace orc_rt::elfnix;

// Declare function tags for functions in the JIT process.
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_reoptimize_tag)
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_push_initializers_tag)
ORC_RT_JIT_DISPATCH_TAG(__orc_rt_elfnix_symbol_lookup_tag)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager,
public:
/// Create redirection manager that uses JITLink based implementaion.
static Expected<std::unique_ptr<RedirectableSymbolManager>>
Create(ExecutionSession &ES, ObjectLinkingLayer &ObjLinkingLayer,
JITDylib &JD) {
Create(ObjectLinkingLayer &ObjLinkingLayer, JITDylib &JD) {
Error Err = Error::success();
auto RM = std::unique_ptr<RedirectableSymbolManager>(
new JITLinkRedirectableSymbolManager(ES, ObjLinkingLayer, JD, Err));
new JITLinkRedirectableSymbolManager(ObjLinkingLayer, JD, Err));
if (Err)
return Err;
return std::move(RM);
Expand All @@ -53,38 +52,40 @@ class JITLinkRedirectableSymbolManager : public RedirectableSymbolManager,
constexpr static StringRef JumpStubTableName = "$IND_JUMP_";
constexpr static StringRef StubPtrTableName = "$__IND_JUMP_PTRS";

JITLinkRedirectableSymbolManager(ExecutionSession &ES,
ObjectLinkingLayer &ObjLinkingLayer,
JITLinkRedirectableSymbolManager(ObjectLinkingLayer &ObjLinkingLayer,
JITDylib &JD, Error &Err)
: ES(ES), ObjLinkingLayer(ObjLinkingLayer), JD(JD),
AnonymousPtrCreator(
jitlink::getAnonymousPointerCreator(ES.getTargetTriple())),
PtrJumpStubCreator(
jitlink::getPointerJumpStubCreator(ES.getTargetTriple())) {
: ObjLinkingLayer(ObjLinkingLayer), JD(JD),
AnonymousPtrCreator(jitlink::getAnonymousPointerCreator(
ObjLinkingLayer.getExecutionSession().getTargetTriple())),
PtrJumpStubCreator(jitlink::getPointerJumpStubCreator(
ObjLinkingLayer.getExecutionSession().getTargetTriple())) {
if (!AnonymousPtrCreator || !PtrJumpStubCreator)
Err = make_error<StringError>("Architecture not supported",
inconvertibleErrorCode());
if (Err)
return;
ES.registerResourceManager(*this);
ObjLinkingLayer.getExecutionSession().registerResourceManager(*this);
}

~JITLinkRedirectableSymbolManager() { ES.deregisterResourceManager(*this); }
~JITLinkRedirectableSymbolManager() {
ObjLinkingLayer.getExecutionSession().deregisterResourceManager(*this);
}

StringRef JumpStubSymbolName(unsigned I) {
return *ES.intern((JumpStubPrefix + Twine(I)).str());
return *ObjLinkingLayer.getExecutionSession().intern(
(JumpStubPrefix + Twine(I)).str());
}

StringRef StubPtrSymbolName(unsigned I) {
return *ES.intern((StubPtrPrefix + Twine(I)).str());
return *ObjLinkingLayer.getExecutionSession().intern(
(StubPtrPrefix + Twine(I)).str());
}

unsigned GetNumAvailableStubs() const { return AvailableStubs.size(); }

Error redirectInner(JITDylib &TargetJD, const SymbolAddrMap &NewDests);
Error grow(unsigned Need);

ExecutionSession &ES;
ObjectLinkingLayer &ObjLinkingLayer;
JITDylib &JD;
jitlink::AnonymousPointerCreator AnonymousPtrCreator;
Expand Down
181 changes: 181 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/ReOptimizeLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
//===- ReOptimizeLayer.h - Re-optimization layer interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Re-optimization layer interface.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_EXECUTIONENGINE_ORC_REOPTIMIZELAYER_H
#define LLVM_EXECUTIONENGINE_ORC_REOPTIMIZELAYER_H

#include "llvm/ExecutionEngine/Orc/Core.h"
#include "llvm/ExecutionEngine/Orc/Layer.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "llvm/ExecutionEngine/Orc/RedirectionManager.h"
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"

namespace llvm {
namespace orc {

class ReOptimizeLayer : public IRLayer, public ResourceManager {
public:
using ReOptMaterializationUnitID = uint64_t;

/// AddProfilerFunc will be called when ReOptimizeLayer emits the first
/// version of a materialization unit in order to inject profiling code and
/// reoptimization request code.
using AddProfilerFunc = unique_function<Error(
ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID,
unsigned CurVersion, ThreadSafeModule &TSM)>;

/// ReOptimizeFunc will be called when ReOptimizeLayer reoptimization of a
/// materialization unit was requested in order to reoptimize the IR module
/// based on profile data. OldRT is the ResourceTracker that tracks the old
/// function definitions. The OldRT must be kept alive until it can be
/// guaranteed that every invocation of the old function definitions has been
/// terminated.
using ReOptimizeFunc = unique_function<Error(
ReOptimizeLayer &Parent, ReOptMaterializationUnitID MUID,
unsigned CurVersion, ResourceTrackerSP OldRT, ThreadSafeModule &TSM)>;

ReOptimizeLayer(ExecutionSession &ES, DataLayout &DL, IRLayer &BaseLayer,
RedirectableSymbolManager &RM)
: IRLayer(ES, BaseLayer.getManglingOptions()), ES(ES), Mangle(ES, DL),
BaseLayer(BaseLayer), RSManager(RM), ReOptFunc(identity),
ProfilerFunc(reoptimizeIfCallFrequent) {}

void setReoptimizeFunc(ReOptimizeFunc ReOptFunc) {
this->ReOptFunc = std::move(ReOptFunc);
}

void setAddProfilerFunc(AddProfilerFunc ProfilerFunc) {
this->ProfilerFunc = std::move(ProfilerFunc);
}

/// Registers reoptimize runtime dispatch handlers to given PlatformJD. The
/// reoptimization request will not be handled if dispatch handler is not
/// registered by using this function.
Error reigsterRuntimeFunctions(JITDylib &PlatformJD);

/// Emits the given module. This should not be called by clients: it will be
/// called by the JIT when a definition added via the add method is requested.
void emit(std::unique_ptr<MaterializationResponsibility> R,
ThreadSafeModule TSM) override;

static const uint64_t CallCountThreshold = 10;

/// Basic AddProfilerFunc that reoptimizes the function when the call count
/// exceeds CallCountThreshold.
static Error reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
ReOptMaterializationUnitID MUID,
unsigned CurVersion,
ThreadSafeModule &TSM);

static Error identity(ReOptimizeLayer &Parent,
ReOptMaterializationUnitID MUID, unsigned CurVersion,
ResourceTrackerSP OldRT, ThreadSafeModule &TSM) {
return Error::success();
}

// Create IR reoptimize request fucntion call.
static void createReoptimizeCall(Module &M, Instruction &IP,
GlobalVariable *ArgBuffer);

Error handleRemoveResources(JITDylib &JD, ResourceKey K) override;
void handleTransferResources(JITDylib &JD, ResourceKey DstK,
ResourceKey SrcK) override;

private:
class ReOptMaterializationUnitState {
public:
ReOptMaterializationUnitState() = default;
ReOptMaterializationUnitState(ReOptMaterializationUnitID ID,
ThreadSafeModule TSM)
: ID(ID), TSM(std::move(TSM)) {}
ReOptMaterializationUnitState(ReOptMaterializationUnitState &&Other)
: ID(Other.ID), TSM(std::move(Other.TSM)), RT(std::move(Other.RT)),
Reoptimizing(std::move(Other.Reoptimizing)),
CurVersion(Other.CurVersion) {}

ReOptMaterializationUnitID getID() { return ID; }

const ThreadSafeModule &getThreadSafeModule() { return TSM; }

ResourceTrackerSP getResourceTracker() {
std::unique_lock<std::mutex> Lock(Mutex);
return RT;
}

void setResourceTracker(ResourceTrackerSP RT) {
std::unique_lock<std::mutex> Lock(Mutex);
this->RT = RT;
}

uint32_t getCurVersion() {
std::unique_lock<std::mutex> Lock(Mutex);
return CurVersion;
}

bool tryStartReoptimize();
void reoptimizeSucceeded();
void reoptimizeFailed();

private:
std::mutex Mutex;
ReOptMaterializationUnitID ID;
ThreadSafeModule TSM;
ResourceTrackerSP RT;
bool Reoptimizing = false;
uint32_t CurVersion = 0;
};

using SPSReoptimizeArgList =
shared::SPSArgList<ReOptMaterializationUnitID, uint32_t>;
using SendErrorFn = unique_function<void(Error)>;

Expected<SymbolMap> emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
uint32_t Version, JITDylib &JD,
ThreadSafeModule TSM);

void rt_reoptimize(SendErrorFn SendResult, ReOptMaterializationUnitID MUID,
uint32_t CurVersion);

static Expected<Constant *>
createReoptimizeArgBuffer(Module &M, ReOptMaterializationUnitID MUID,
uint32_t CurVersion);

ReOptMaterializationUnitState &
createMaterializationUnitState(const ThreadSafeModule &TSM);

void
registerMaterializationUnitResource(ResourceKey Key,
ReOptMaterializationUnitState &State);

ReOptMaterializationUnitState &
getMaterializationUnitState(ReOptMaterializationUnitID MUID);

ExecutionSession &ES;
MangleAndInterner Mangle;
IRLayer &BaseLayer;
RedirectableSymbolManager &RSManager;

ReOptimizeFunc ReOptFunc;
AddProfilerFunc ProfilerFunc;

std::mutex Mutex;
std::map<ReOptMaterializationUnitID, ReOptMaterializationUnitState> MUStates;
DenseMap<ResourceKey, DenseSet<ReOptMaterializationUnitID>> MUResources;
ReOptMaterializationUnitID NextID = 1;
};

} // namespace orc
} // namespace llvm

#endif
2 changes: 2 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/RedirectionManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class RedirectableSymbolManager : public RedirectionManager {
const SymbolMap &InitialDests) = 0;
};

/// RedirectableMaterializationUnit materializes redirectable symbol
/// by invoking RedirectableSymbolManager::emitRedirectableSymbols
class RedirectableMaterializationUnit : public MaterializationUnit {
public:
RedirectableMaterializationUnit(RedirectableSymbolManager &RM,
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/ExecutionEngine/Orc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ add_llvm_component_library(LLVMOrcJIT
ThreadSafeModule.cpp
RedirectionManager.cpp
JITLinkRedirectableSymbolManager.cpp
ReOptimizeLayer.cpp
ADDITIONAL_HEADER_DIRS
${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc

Expand Down
19 changes: 11 additions & 8 deletions llvm/lib/ExecutionEngine/Orc/JITLinkRedirectableSymbolManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using namespace llvm::orc;
void JITLinkRedirectableSymbolManager::emitRedirectableSymbols(
std::unique_ptr<MaterializationResponsibility> R,
const SymbolAddrMap &InitialDests) {
auto &ES = ObjLinkingLayer.getExecutionSession();
std::unique_lock<std::mutex> Lock(Mutex);
if (GetNumAvailableStubs() < InitialDests.size())
if (auto Err = grow(InitialDests.size() - GetNumAvailableStubs())) {
Expand All @@ -37,21 +38,21 @@ void JITLinkRedirectableSymbolManager::emitRedirectableSymbols(
R->failMaterialization();
return;
}
dbgs() << *K << "\n";
SymbolToStubs[&TargetJD][K] = StubID;
NewSymbolDefs[K] = JumpStubs[StubID];
NewSymbolDefs[K].setFlags(V.getFlags());
Symbols.push_back(K);
AvailableStubs.pop_back();
}

if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) {
// FIXME: when this fails we can return stubs to the pool
if (auto Err = redirectInner(TargetJD, InitialDests)) {
ES.reportError(std::move(Err));
R->failMaterialization();
return;
}

if (auto Err = redirectInner(TargetJD, InitialDests)) {
if (auto Err = R->replace(absoluteSymbols(NewSymbolDefs))) {
ES.reportError(std::move(Err));
R->failMaterialization();
return;
Expand Down Expand Up @@ -85,10 +86,10 @@ Error JITLinkRedirectableSymbolManager::redirectInner(
StubHandle StubID = SymbolToStubs[&TargetJD].at(K);
PtrWrites.push_back({StubPointers[StubID].getAddress(), V.getAddress()});
}
if (auto Err = ES.getExecutorProcessControl().getMemoryAccess().writePointers(
PtrWrites))
return Err;
return Error::success();
return ObjLinkingLayer.getExecutionSession()
.getExecutorProcessControl()
.getMemoryAccess()
.writePointers(PtrWrites);
}

Error JITLinkRedirectableSymbolManager::grow(unsigned Need) {
Expand All @@ -103,16 +104,18 @@ Error JITLinkRedirectableSymbolManager::grow(unsigned Need) {
SymbolLookupSet LookupSymbols;
DenseMap<SymbolStringPtr, ExecutorSymbolDef *> NewDefsMap;

auto &ES = ObjLinkingLayer.getExecutionSession();
Triple TT = ES.getTargetTriple();
auto G = std::make_unique<jitlink::LinkGraph>(
"<INDIRECT STUBS>", TT, TT.isArch64Bit() ? 8 : 4,
TT.isLittleEndian() ? support::little : support::big,
TT.isLittleEndian() ? endianness::little : endianness::big,
jitlink::getGenericEdgeKindName);
auto &PointerSection =
G->createSection(StubPtrTableName, MemProt::Write | MemProt::Read);
auto &StubsSection =
G->createSection(JumpStubTableName, MemProt::Exec | MemProt::Read);

// FIXME: We can batch the stubs into one block and use address to access them
for (size_t I = OldSize; I < NewSize; I++) {
auto Pointer = AnonymousPtrCreator(*G, PointerSection, nullptr, 0);
if (auto Err = Pointer.takeError())
Expand Down
Loading

0 comments on commit 188ede2

Please sign in to comment.