Skip to content

Commit

Permalink
Proof of concept for JITing functions in a pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
lhames committed Jan 16, 2024
1 parent c58bc24 commit 39fa521
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 32 deletions.
2 changes: 1 addition & 1 deletion libunwind/src/libunwind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ void __unw_add_dynamic_eh_frame_section(unw_word_t eh_frame_start) {
CFI_Parser<LocalAddressSpace>::CIE_Info cieInfo;
CFI_Parser<LocalAddressSpace>::FDE_Info fdeInfo;
auto p = (LocalAddressSpace::pint_t)eh_frame_start;
while (true) {
while (LocalAddressSpace::sThisAddressSpace.get32(p)) {
if (CFI_Parser<LocalAddressSpace>::decodeFDE(
LocalAddressSpace::sThisAddressSpace, p, &fdeInfo, &cieInfo,
true) == NULL) {
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/ExecutionEngine/Orc/ThreadSafeModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ cloneToNewContext(const ThreadSafeModule &TSMW,
GVPredicate ShouldCloneDef = GVPredicate(),
GVModifier UpdateClonedDefSource = GVModifier());

/// Clones the given module on to a new context. This overload should only be
/// used if the caller knows that the given Module will not be concurrently
/// accessed during the clone.
ThreadSafeModule
cloneToNewContext(Module &M, GVPredicate ShouldCloneDef = GVPredicate(),
GVModifier UpdateClonedDefSource = GVModifier());

} // End namespace orc
} // End namespace llvm

Expand Down
22 changes: 22 additions & 0 deletions llvm/include/llvm/Transforms/Utils/MyJITPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//===-------------------------- MyJITPass.h ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_TRANSFORMS_UTILS_MYJITPASS_H
#define LLVM_TRANSFORMS_UTILS_MYJITPASS_H

#include "llvm/IR/PassManager.h"

namespace llvm {
class MyJITPass : public PassInfoMixin<MyJITPass> {
public:
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_TRANSFORMS_UTILS_MYJITPASS_H
69 changes: 38 additions & 31 deletions llvm/lib/ExecutionEngine/Orc/ThreadSafeModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,52 @@ ThreadSafeModule cloneToNewContext(const ThreadSafeModule &TSM,
GVPredicate ShouldCloneDef,
GVModifier UpdateClonedDefSource) {
assert(TSM && "Can not clone null module");
return TSM.withModuleDo([&](Module &M) {
return cloneToNewContext(M, std::move(ShouldCloneDef),
std::move(UpdateClonedDefSource));
});
}

ThreadSafeModule
cloneToNewContext(Module &M, GVPredicate ShouldCloneDef,
GVModifier UpdateClonedDefSource) {

if (!ShouldCloneDef)
ShouldCloneDef = [](const GlobalValue &) { return true; };

SmallVector<char, 1> ClonedModuleBuffer;

return TSM.withModuleDo([&](Module &M) {
SmallVector<char, 1> ClonedModuleBuffer;

{
std::set<GlobalValue *> ClonedDefsInSrc;
ValueToValueMapTy VMap;
auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) {
if (ShouldCloneDef(*GV)) {
ClonedDefsInSrc.insert(const_cast<GlobalValue *>(GV));
return true;
}
return false;
});
{
std::set<GlobalValue *> ClonedDefsInSrc;
ValueToValueMapTy VMap;
auto Tmp = CloneModule(M, VMap, [&](const GlobalValue *GV) {
if (ShouldCloneDef(*GV)) {
ClonedDefsInSrc.insert(const_cast<GlobalValue *>(GV));
return true;
}
return false;
});

if (UpdateClonedDefSource)
for (auto *GV : ClonedDefsInSrc)
UpdateClonedDefSource(*GV);

if (UpdateClonedDefSource)
for (auto *GV : ClonedDefsInSrc)
UpdateClonedDefSource(*GV);
BitcodeWriter BCWriter(ClonedModuleBuffer);

BitcodeWriter BCWriter(ClonedModuleBuffer);
BCWriter.writeModule(*Tmp);
BCWriter.writeSymtab();
BCWriter.writeStrtab();
}

BCWriter.writeModule(*Tmp);
BCWriter.writeSymtab();
BCWriter.writeStrtab();
}
MemoryBufferRef ClonedModuleBufferRef(
StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()),
"cloned module buffer");
ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>());

MemoryBufferRef ClonedModuleBufferRef(
StringRef(ClonedModuleBuffer.data(), ClonedModuleBuffer.size()),
"cloned module buffer");
ThreadSafeContext NewTSCtx(std::make_unique<LLVMContext>());

auto ClonedModule = cantFail(
parseBitcodeFile(ClonedModuleBufferRef, *NewTSCtx.getContext()));
ClonedModule->setModuleIdentifier(M.getName());
return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx));
});
auto ClonedModule = cantFail(
parseBitcodeFile(ClonedModuleBufferRef, *NewTSCtx.getContext()));
ClonedModule->setModuleIdentifier(M.getName());
return ThreadSafeModule(std::move(ClonedModule), std::move(NewTSCtx));
}

} // end namespace orc
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@
#include "llvm/Transforms/Utils/Mem2Reg.h"
#include "llvm/Transforms/Utils/MetaRenamer.h"
#include "llvm/Transforms/Utils/MoveAutoInit.h"
#include "llvm/Transforms/Utils/MyJITPass.h"
#include "llvm/Transforms/Utils/NameAnonGlobals.h"
#include "llvm/Transforms/Utils/PredicateInfo.h"
#include "llvm/Transforms/Utils/RelLookupTableConverter.h"
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ FUNCTION_PASS("memprof", MemProfilerPass())
FUNCTION_PASS("mergeicmps", MergeICmpsPass())
FUNCTION_PASS("mergereturn", UnifyFunctionExitNodesPass())
FUNCTION_PASS("move-auto-init", MoveAutoInitPass())
FUNCTION_PASS("myjitpass", MyJITPass())
FUNCTION_PASS("nary-reassociate", NaryReassociatePass())
FUNCTION_PASS("newgvn", NewGVNPass())
FUNCTION_PASS("no-op-function", NoOpFunctionPass())
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ add_llvm_component_library(LLVMTransformUtils
MisExpect.cpp
ModuleUtils.cpp
MoveAutoInit.cpp
MyJITPass.cpp
NameAnonGlobals.cpp
PredicateInfo.cpp
PromoteMemoryToRegister.cpp
Expand Down Expand Up @@ -96,6 +97,8 @@ add_llvm_component_library(LLVMTransformUtils
LINK_COMPONENTS
Analysis
Core
OrcJIT
OrcTargetProcess
Support
TargetParser
)
51 changes: 51 additions & 0 deletions llvm/lib/Transforms/Utils/MyJITPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===---------------------------- MyJITPass.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/MyJITPass.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"

using namespace llvm;
using namespace llvm::orc;

PreservedAnalyses MyJITPass::run(Function &F,
FunctionAnalysisManager &AM) {
auto J = LLJITBuilder().create();
if (!J) {
errs() << "MyJITPass could not create JIT instance for "
<< F.getName() << ": " << toString(J.takeError()) << "\n";
return PreservedAnalyses::all();
}

auto I32I32FnTy = FunctionType::get(Type::getInt32Ty(F.getContext()),
{Type::getInt32Ty(F.getContext())}, false);
if (F.getFunctionType() != I32I32FnTy) {
errs() << "MyJITPass: Function " << F.getName()
<< " does not have required type. Skipping.\n";
return PreservedAnalyses::all();
}

auto FM = cloneToNewContext(*F.getParent(),
[&](const GlobalValue &GV) { return &GV == &F; });
if (auto Err = (*J)->addIRModule(std::move(FM))) {
errs() << "MyJITPass could not add extracted module for "
<< F.getName() << ": " << toString(std::move(Err)) << "\n";
return PreservedAnalyses::all();
}

auto FSym = (*J)->lookup(F.getName());
if (!FSym) {
errs() << "MyJITPass could not get JIT'd symbol for "
<< F.getName() << ": " << toString(FSym.takeError()) << "\n";
return PreservedAnalyses::all();
}

auto *JittedF = FSym->toPtr<int32_t(int32_t)>();
errs() << "Result: " << JittedF(42) << "\n";

return PreservedAnalyses::all();
}

0 comments on commit 39fa521

Please sign in to comment.