diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index e37bce3118bcb2..985ca1532e0149 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -901,6 +901,12 @@ class TargetTransformInfo { bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx) const; + /// Identifies if the vector form of the intrinsic is overloaded on the type + /// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is + /// -1. + bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) const; + /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the demanded result elements need to be inserted and/or /// extracted from vectors. @@ -1993,6 +1999,8 @@ class TargetTransformInfo::Concept { virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0; virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx) = 0; + virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) = 0; virtual InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, @@ -2569,6 +2577,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx); } + bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) override { + return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx); + } + InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 72038c090b7922..38aba183f6a173 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ 
b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -396,6 +396,11 @@ class TargetTransformInfoImplBase { return false; } + bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) const { + return OpdIdx == -1; + } + InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h index 467d5932cacf91..c1016dd7bdddbd 100644 --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -152,7 +152,10 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, /// Identifies if the vector form of the intrinsic is overloaded on the type of /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1. -bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx); +/// \p TTI is used to consider target-specific intrinsics; if no target-specific +/// intrinsics will be considered then it is appropriate to pass in nullptr. +bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx, + const TargetTransformInfo *TTI); /// Identifies if the vector form of the intrinsic that returns a struct is /// overloaded at the struct element index \p RetIdx. diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 3b098c42f2741c..b3583e2819ee4c 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -801,6 +801,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { return false; } + bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) const { + return OpdIdx == -1; + } + /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead. 
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert, bool Extract, diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 174e5e87abe538..1fb2b9836de0cc 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -615,6 +615,11 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg( return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx); } +bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg( + Intrinsic::ID ID, int OpdIdx) const { + return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx); +} + InstructionCost TargetTransformInfo::getScalarizationOverhead( VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind) const { diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp index 15e325a0fffca5..1789671276ffaf 100644 --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -133,10 +133,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, } } -bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, - int OpdIdx) { +bool llvm::isVectorIntrinsicWithOverloadTypeAtArg( + Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) { assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!"); + if (TTI && Intrinsic::isTargetIntrinsic(ID)) + return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx); + switch (ID) { case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp index 7f3c5cf6cb4436..8d457f58e6eede 100644 --- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp +++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp @@ -110,7 +110,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, // OloadTys collects types used in scalar intrinsic overload name. 
SmallVector OloadTys; - if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1)) + if (!RetTy->isVoidTy() && + isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr)) OloadTys.push_back(ScalarRetTy); // Compute the argument types of the corresponding scalar call and check that @@ -118,7 +119,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, SmallVector ScalarArgTypes; for (auto Arg : enumerate(II->args())) { auto *ArgTy = Arg.value()->getType(); - bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index()); + bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(), + /*TTI=*/nullptr); if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) { ScalarArgTypes.push_back(ArgTy); if (IsOloadTy) diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index b0436a39423405..182cdaa4e9a7d7 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, } } +bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, + int OpdIdx) { + switch (ID) { + default: + return OpdIdx == -1; + } +} + bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( Intrinsic::ID ID) const { switch (ID) { diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h index 30b57ed97d6370..a18e4a28625756 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h @@ -37,6 +37,8 @@ class DirectXTTIImpl : public BasicTTIImplBase { bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const; bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx); + bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, 
+ int OpdIdx); }; } // namespace llvm diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 03d069c9fcb36d..3b701e6ca09761 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -727,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { SmallVector Tys; // Add return type if intrinsic is overloaded on it. - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI)) Tys.push_back(VS->SplitTy); if (AreAllVectorsOfMatchingSize) { @@ -767,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { } Scattered[I] = scatter(&CI, OpI, *OpVS); - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) { + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) { OverloadIdx[I] = Tys.size(); Tys.push_back(OpVS->SplitTy); } } else { ScalarOperands[I] = OpI; - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) Tys.push_back(OpI->getType()); } } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index fda6550a375480..2854c1462014f9 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -7684,7 +7684,8 @@ DenseMap LoopVectorizationPlanner::executePlan( LLVM_DEBUG(BestVPlan.dump()); // Perform the actual loop transformation. - VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan); + VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV, + &BestVPlan); // 0. Generate SCEV-dependent code into the preheader, including TripCount, // before making any changes to the CFG. 
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index dd87d34d1f01a4..f13d0d80d382a4 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15655,7 +15655,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { SmallVector OpVecs; SmallVector TysForDecl; // Add return type if intrinsic is overloaded on it. - if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) + if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI)) TysForDecl.push_back(VecTy); auto *CEI = cast(VL0); for (unsigned I : seq(0, CI->arg_size())) { @@ -15670,7 +15670,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { It->second.first < DL->getTypeSizeInBits(CEI->getType())) ScalarArg = Builder.getFalse(); OpVecs.push_back(ScalarArg); - if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) TysForDecl.push_back(ScalarArg->getType()); continue; } @@ -15692,7 +15692,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) { } LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); - if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) + if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) TysForDecl.push_back(OpVec->getType()); } diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index 8b1a4aeb88f81f..a24a86b4201c31 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -219,10 +219,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { return It; } -VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, +VPTransformState::VPTransformState(const TargetTransformInfo *TTI, + ElementCount VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, 
InnerLoopVectorizer *ILV, VPlan *Plan) - : VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan), + : TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan), LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {} Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) { diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index abfe97b4ab55b6..9ef85a7f7a7524 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -234,9 +234,11 @@ class VPLane { /// VPTransformState holds information passed down when "executing" a VPlan, /// needed for generating the output IR. struct VPTransformState { - VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI, - DominatorTree *DT, IRBuilderBase &Builder, + VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF, + LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder, InnerLoopVectorizer *ILV, VPlan *Plan); + /// Target Transform Info, queried for target intrinsic overload types. + const TargetTransformInfo *TTI; /// The chosen Vectorization Factor of the loop being vectorized. ElementCount VF; diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index ef2ca9af7268d1..71aca3be9e5dcb 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -941,7 +941,7 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) { SmallVector TysForDecl; // Add return type if intrinsic is overloaded on it. 
- if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) + if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI)) TysForDecl.push_back(VectorType::get(getResultType(), State.VF)); SmallVector Args; for (const auto &I : enumerate(operands())) { @@ -952,7 +952,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) { Arg = State.get(I.value(), VPLane(0)); else Arg = State.get(I.value(), onlyFirstLaneUsed(I.value())); - if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index())) + if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(), + State.TTI)) TysForDecl.push_back(Arg->getType()); Args.push_back(Arg); }