From 1136fba0f82885058f6d177f37d56ce9133616a0 Mon Sep 17 00:00:00 2001 From: TB Schardl Date: Tue, 1 Oct 2024 07:51:48 -0400 Subject: [PATCH] [CSI] Synthesize custom hooks for allocation functions, such as strdup, that can't be instrumented through normal means. --- .../llvm/Transforms/Instrumentation/CSI.h | 32 ++++ .../ComprehensiveStaticInstrumentation.cpp | 168 +++++++++++++++--- 2 files changed, 172 insertions(+), 28 deletions(-) diff --git a/llvm/include/llvm/Transforms/Instrumentation/CSI.h b/llvm/include/llvm/Transforms/Instrumentation/CSI.h index 60798746e2f6..8cbaa4f84786 100644 --- a/llvm/include/llvm/Transforms/Instrumentation/CSI.h +++ b/llvm/include/llvm/Transforms/Instrumentation/CSI.h @@ -1175,6 +1175,8 @@ struct CSIImpl { return IRB.getInt64(CsiUnknownId); } + Value *getCalleeFuncID(const Function *Callee, IRBuilder<> &IRB); + static bool spawnsTapirLoopBody(DetachInst *DI, LoopInfo &LI, TaskInfo &TI); static BasicBlock::iterator @@ -1189,6 +1191,32 @@ struct CSIImpl { /// Finalize the CSI pass. void finalizeCsi(); + FunctionCallee getHookFunction(StringRef Name, FunctionType *FnTy, + AttributeList AL) { + FunctionCallee Callee = M.getOrInsertFunction(Name, FnTy, AL); + if (Function *Fn = dyn_cast(Callee.getCallee())) { + Fn->setOnlyAccessesInaccessibleMemOrArgMem(); + } + return Callee; + } + template + FunctionCallee getHookFunction(StringRef Name, AttributeList AL, Type *RetTy, + ArgsTy... Args) { + FunctionCallee Callee = M.getOrInsertFunction(Name, AL, RetTy, Args...); + if (Function *Fn = dyn_cast(Callee.getCallee())) { + MemoryEffects ME = MemoryEffects::argMemOnly(ModRefInfo::Ref) | + MemoryEffects::inaccessibleMemOnly(ModRefInfo::ModRef); + Fn->setMemoryEffects(ME); + Fn->setDoesNotThrow(); + } + return Callee; + } + template + FunctionCallee getHookFunction(StringRef Name, Type *RetTy, + ArgsTy... Args) { + return getHookFunction(Name, AttributeList{}, RetTy, Args...); + } + /// Initialize FunctionCallees for the CSI hooks. /// @{ void initializeLoadStoreHooks(); @@ -1202,6 +1230,9 @@ struct CSIImpl { void initializeAllocFnHooks(); /// @} + FunctionCallee getOrInsertSynthesizedHook(StringRef Name, FunctionType *T, + AttributeList AL = AttributeList()); + static StructType *getUnitFedTableType(LLVMContext &C, PointerType *EntryPointerType); static Constant *fedTableToUnitFedTable(Module &M, @@ -1255,6 +1286,7 @@ struct CSIImpl { LoopInfo &LI); void instrumentSync(SyncInst *SI, unsigned SyncRegNum); void instrumentAlloca(Instruction *I, TaskInfo &TI); + bool instrumentAllocFnLibCall(Instruction *I, const TargetLibraryInfo *TLI); void instrumentAllocFn(Instruction *I, DominatorTree *DT, const TargetLibraryInfo *TLI); void instrumentFree(Instruction *I, const TargetLibraryInfo *TLI); diff --git a/llvm/lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp index b3ca25cdbc2e..f621b5305a25 100644 --- a/llvm/lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/ComprehensiveStaticInstrumentation.cpp @@ -746,6 +746,44 @@ void CSIImpl::initializeTapirHooks() { M.getOrInsertFunction("__csi_after_sync", RetType, IDType, SyncRegType); } +FunctionCallee CSIImpl::getOrInsertSynthesizedHook(StringRef Name, + FunctionType *T, + AttributeList AL) { + // If no bitcode file has been linked, then we cannot check if it contains a + // particular library hook. Simply return the hook. If the Cilksan library + // doesn't contain that hook, the linker will raise an error. + if (!LinkedBitcode) + return getHookFunction(Name, T, AL); + + // Check if the linked bitcode file contains the library hook. If it does, + // return that hook. + if (FunctionsInBitcode.contains(std::string(Name))) + return getHookFunction(Name, T, AL); + + // We did not find the library hook in the linked bitcode file. Synthesize a + // default version of the hook that simply calls __csi_default_libhook. + FunctionCallee NewHook = M.getOrInsertFunction(Name, T, AL); + Function *NewHookFn = cast(NewHook.getCallee()); + NewHookFn->setOnlyAccessesInaccessibleMemOrArgMem(); + NewHookFn->setDoesNotThrow(); + BasicBlock *Entry = BasicBlock::Create(M.getContext(), "entry", NewHookFn); + IRBuilder<> IRB(ReturnInst::Create(M.getContext(), Entry)); + + // Insert a call to the default library function hook + Type *IDType = IRB.getInt64Ty(); + FunctionType *DefaultHookTy = + FunctionType::get(IRB.getVoidTy(), + {/*call_id*/ + IDType, /*func_id*/ IDType, + /*MAAP_count*/ IRB.getInt8Ty()}, + /*isVarArg*/ false); + FunctionCallee DefaultHook = + M.getOrInsertFunction("__csi_default_libhook", DefaultHookTy); + IRB.CreateCall(DefaultHook, {NewHookFn->getArg(0), NewHookFn->getArg(1), + NewHookFn->getArg(2)}); + return NewHook; +} + // Prepare any calls in the CFG for instrumentation, e.g., by making sure any // call that can throw is modeled with an invoke. void CSIImpl::setupCalls(Function &F) { @@ -1257,16 +1295,40 @@ void CSIImpl::instrumentLoop(Loop &L, TaskInfo &TI, ScalarEvolution *SE) { } } +// Helper function to get the ID of a function being called. These IDs are +// stored in separate global variables in the program. This method will create +// a new global variable for the Callee's ID if necessary. +Value *CSIImpl::getCalleeFuncID(const Function *Callee, IRBuilder<> &IRB) { + if (!Callee) + // Unknown targets (i.e., indirect calls) are always unknown. + return IRB.getInt64(CsiCallsiteUnknownTargetId); + + std::string GVName = + CsiFuncIdVariablePrefix + Callee->getName().str(); + GlobalVariable *FuncIdGV = M.getNamedGlobal(GVName); + Type *FuncIdGVTy = IRB.getInt64Ty(); + if (!FuncIdGV) { + FuncIdGV = + dyn_cast(M.getOrInsertGlobal(GVName, FuncIdGVTy)); + assert(FuncIdGV); + FuncIdGV->setConstant(false); + if (Options.jitMode && !Callee->empty()) + FuncIdGV->setLinkage(Callee->getLinkage()); + else + FuncIdGV->setLinkage(GlobalValue::WeakAnyLinkage); + FuncIdGV->setInitializer(IRB.getInt64(CsiCallsiteUnknownTargetId)); + } + return IRB.CreateLoad(FuncIdGVTy, FuncIdGV); +} + void CSIImpl::instrumentCallsite(Instruction *I, DominatorTree *DT) { if (callsPlaceholderFunction(*I)) return; bool IsInvoke = isa(I); Function *Called = nullptr; - if (CallInst *CI = dyn_cast(I)) - Called = CI->getCalledFunction(); - else if (InvokeInst *II = dyn_cast(I)) - Called = II->getCalledFunction(); + if (CallBase *CB = dyn_cast(I)) + Called = CB->getCalledFunction(); bool ShouldInstrumentBefore = true; bool ShouldInstrumentAfter = true; @@ -1286,25 +1348,7 @@ void CSIImpl::instrumentCallsite(Instruction *I, DominatorTree *DT) { Value *DefaultID = getDefaultID(IRB); uint64_t LocalId = CallsiteFED.add(*I, Called ? Called->getName() : ""); Value *CallsiteId = CallsiteFED.localToGlobalId(LocalId, IRB); - Value *FuncId = nullptr; - GlobalVariable *FuncIdGV = nullptr; - if (Called) { - std::string GVName = CsiFuncIdVariablePrefix + Called->getName().str(); - Type *FuncIdGVTy = IRB.getInt64Ty(); - FuncIdGV = dyn_cast( - M.getOrInsertGlobal(GVName, FuncIdGVTy)); - assert(FuncIdGV); - FuncIdGV->setConstant(false); - if (Options.jitMode && !Called->empty()) - FuncIdGV->setLinkage(Called->getLinkage()); - else - FuncIdGV->setLinkage(GlobalValue::WeakAnyLinkage); - FuncIdGV->setInitializer(IRB.getInt64(CsiCallsiteUnknownTargetId)); - FuncId = IRB.CreateLoad(FuncIdGVTy, FuncIdGV); - } else { - // Unknown targets (i.e. indirect calls) are always unknown. - FuncId = IRB.getInt64(CsiCallsiteUnknownTargetId); - } + Value *FuncId = getCalleeFuncID(Called, IRB); assert(FuncId != NULL); CsiCallProperty Prop; Value *DefaultPropVal = Prop.getValue(IRB); @@ -1651,14 +1695,79 @@ bool CSIImpl::getAllocFnArgs(const Instruction *I, return true; } +bool CSIImpl::instrumentAllocFnLibCall(Instruction *I, + const TargetLibraryInfo *TLI) { + bool IsInvoke = isa(I); + CallBase *CB = dyn_cast(I); + if (!CB) + return false; + Function *Called = CB->getCalledFunction(); + + // Get the CSI IDs for this hook + IRBuilder<> IRB(I); + LLVMContext &Ctx = IRB.getContext(); + Value *DefaultID = getDefaultID(IRB); + uint64_t LocalId = AllocFnFED.add(*I); + Value *AllocFnId = AllocFnFED.localToGlobalId(LocalId, IRB); + Value *FuncId = getCalleeFuncID(Called, IRB); + assert(FuncId != NULL); + + CsiAllocFnProperty Prop; + Value *DefaultPropVal = Prop.getValue(IRB); + LibFunc AllocLibF; + TLI->getLibFunc(*Called, AllocLibF); + Prop.setAllocFnTy(static_cast(getAllocFnTy(AllocLibF))); + Value *PropVal = Prop.getValue(IRB); + Type *IDType = IRB.getInt64Ty(); + + // Synthesize the after hook for this function. + SmallVector AfterHookParamTys( + {IDType, /*callee func_id*/ IDType, CsiAllocFnProperty::getType(Ctx)}); + SmallVector AfterHookParamVals({AllocFnId, FuncId, PropVal}); + SmallVector AfterHookDefaultVals( + {DefaultID, DefaultID, DefaultPropVal}); + if (!Called->getReturnType()->isVoidTy()) { + AfterHookParamTys.push_back(Called->getReturnType()); + AfterHookParamVals.push_back(CB); + AfterHookDefaultVals.push_back( + Constant::getNullValue(Called->getReturnType())); + } + AfterHookParamTys.append(Called->getFunctionType()->param_begin(), + Called->getFunctionType()->param_end()); + AfterHookParamVals.append(CB->arg_begin(), CB->arg_end()); + for (Value *Arg : CB->args()) + AfterHookDefaultVals.push_back(Constant::getNullValue(Arg->getType())); + FunctionType *AfterHookTy = + FunctionType::get(IRB.getVoidTy(), AfterHookParamTys, Called->isVarArg()); + FunctionCallee AfterLibCallHook = getOrInsertSynthesizedHook( + ("__csi_alloc_" + Called->getName()).str(), AfterHookTy); + + // Insert the hook after the call. + BasicBlock::iterator Iter(I); + if (IsInvoke) { + // There are two "after" positions for invokes: the normal block and the + // exception block. + InvokeInst *II = cast(I); + insertHookCallInSuccessorBB(II->getNormalDest(), II->getParent(), + AfterLibCallHook, AfterHookParamVals, + AfterHookDefaultVals); + // Don't insert any instrumentation in the exception block. + } else { + // Simple call instruction; there is only one "after" position. + Iter++; + IRB.SetInsertPoint(&*Iter); + insertHookCall(&*Iter, AfterLibCallHook, AfterHookParamVals); + } + + return true; +} + void CSIImpl::instrumentAllocFn(Instruction *I, DominatorTree *DT, const TargetLibraryInfo *TLI) { bool IsInvoke = isa(I); Function *Called = nullptr; - if (CallInst *CI = dyn_cast(I)) - Called = CI->getCalledFunction(); - else if (InvokeInst *II = dyn_cast(I)) - Called = II->getCalledFunction(); + if (CallBase *CB = dyn_cast(I)) + Called = CB->getCalledFunction(); assert(Called && "Could not get called function for allocation fn."); @@ -1668,7 +1777,10 @@ void CSIImpl::instrumentAllocFn(Instruction *I, DominatorTree *DT, Value *AllocFnId = AllocFnFED.localToGlobalId(LocalId, IRB); SmallVector AllocFnArgs; - getAllocFnArgs(I, AllocFnArgs, IntptrTy, IRB.getPtrTy(), *TLI); + if (!getAllocFnArgs(I, AllocFnArgs, IntptrTy, IRB.getPtrTy(), *TLI)) { + instrumentAllocFnLibCall(I, TLI); + return; + } SmallVector DefaultAllocFnArgs({ /* Allocated size */ Constant::getNullValue(IntptrTy), /* Number of elements */ Constant::getNullValue(IntptrTy),