Skip to content

Commit

Permalink
[NFC][VectorUtils][TargetTransformInfo] Add `isVectorIntrinsicWithOve…
Browse files Browse the repository at this point in the history
…rloadTypeAtArg` api (llvm#114849)

This changes allows target intrinsics to specify and overwrite overloaded types.

- Updates `ReplaceWithVecLib` to not provide TTI as there most probably won't be a use-case
- Updates `SLPVectorizer` to use available TTI
- Updates `VPTransformState` to pass down TTI
- Updates `VPlanRecipe` to use passed-down TTI

This change will let us add scalarization for `asdouble`:  llvm#114847
  • Loading branch information
inbelic authored Nov 21, 2024
1 parent f7497b1 commit 8663b87
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 18 deletions.
13 changes: 13 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,12 @@ class TargetTransformInfo {
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) const;

/// Identifies if the vector form of the intrinsic is overloaded on the type
/// of the operand at index \p OpdIdx, or on the return type if \p OpdIdx is
/// -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const;

/// Estimate the overhead of scalarizing an instruction. Insert and Extract
/// are set if the demanded result elements need to be inserted and/or
/// extracted from vectors.
Expand Down Expand Up @@ -1993,6 +1999,8 @@ class TargetTransformInfo::Concept {
virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0;
virtual bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx) = 0;
virtual bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) = 0;
virtual InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
Expand Down Expand Up @@ -2569,6 +2577,11 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) override {
return Impl.isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ class TargetTransformInfoImplBase {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
}

InstructionCost getScalarizationOverhead(VectorType *Ty,
const APInt &DemandedElts,
bool Insert, bool Extract,
Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/Analysis/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx);
/// \p TTI is used to consider target specific intrinsics, if no target specific
/// intrinsics will be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,11 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return false;
}

bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) const {
return ScalarOpdIdx == -1;
}

/// Helper wrapper for the DemandedElts variant of getScalarizationOverhead.
InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert,
bool Extract,
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,11 @@ bool TargetTransformInfo::isTargetIntrinsicWithScalarOpAtArg(
return TTIImpl->isTargetIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx);
}

bool TargetTransformInfo::isVectorIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int ScalarOpdIdx) const {
return TTIImpl->isVectorIntrinsicWithOverloadTypeAtArg(ID, ScalarOpdIdx);
}

InstructionCost TargetTransformInfo::getScalarizationOverhead(
VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
TTI::TargetCostKind CostKind) const {
Expand Down
7 changes: 5 additions & 2 deletions llvm/lib/Analysis/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ bool llvm::isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
}
}

bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) {
bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(
Intrinsic::ID ID, int OpdIdx, const TargetTransformInfo *TTI) {
assert(ID != Intrinsic::not_intrinsic && "Not an intrinsic!");

if (TTI && Intrinsic::isTargetIntrinsic(ID))
return TTI->isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx);

switch (ID) {
case Intrinsic::fptosi_sat:
case Intrinsic::fptoui_sat:
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,17 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,

// OloadTys collects types used in scalar intrinsic overload name.
SmallVector<Type *, 3> OloadTys;
if (!RetTy->isVoidTy() && isVectorIntrinsicWithOverloadTypeAtArg(IID, -1))
if (!RetTy->isVoidTy() &&
isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
OloadTys.push_back(ScalarRetTy);

// Compute the argument types of the corresponding scalar call and check that
// all vector operands match the previously found EC.
SmallVector<Type *, 8> ScalarArgTypes;
for (auto Arg : enumerate(II->args())) {
auto *ArgTy = Arg.value()->getType();
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index());
bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
/*TTI=*/nullptr);
if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
ScalarArgTypes.push_back(ArgTy);
if (IsOloadTy)
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ bool DirectXTTIImpl::isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
}
}

bool DirectXTTIImpl::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx) {
switch (ID) {
default:
return ScalarOpdIdx == -1;
}
}

bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
Intrinsic::ID ID) const {
switch (ID) {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DirectXTTIImpl : public BasicTTIImplBase<DirectXTTIImpl> {
bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const;
bool isTargetIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
unsigned ScalarOpdIdx);
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int ScalarOpdIdx);
};
} // namespace llvm

Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/Scalar/Scalarizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {

SmallVector<llvm::Type *, 3> Tys;
// Add return type if intrinsic is overloaded on it.
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
Tys.push_back(VS->SplitTy);

if (AreAllVectorsOfMatchingSize) {
Expand Down Expand Up @@ -767,13 +767,13 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {
}

Scattered[I] = scatter(&CI, OpI, *OpVS);
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I)) {
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI)) {
OverloadIdx[I] = Tys.size();
Tys.push_back(OpVS->SplitTy);
}
} else {
ScalarOperands[I] = OpI;
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
Tys.push_back(OpI->getType());
}
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7684,7 +7684,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
LLVM_DEBUG(BestVPlan.dump());

// Perform the actual loop transformation.
VPTransformState State(BestVF, BestUF, LI, DT, ILV.Builder, &ILV, &BestVPlan);
VPTransformState State(&TTI, BestVF, BestUF, LI, DT, ILV.Builder, &ILV,
&BestVPlan);

// 0. Generate SCEV-dependent code into the preheader, including TripCount,
// before making any changes to the CFG.
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15655,7 +15655,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
SmallVector<Value *> OpVecs;
SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1, TTI))
TysForDecl.push_back(VecTy);
auto *CEI = cast<CallInst>(VL0);
for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
Expand All @@ -15670,7 +15670,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
It->second.first < DL->getTypeSizeInBits(CEI->getType()))
ScalarArg = Builder.getFalse();
OpVecs.push_back(ScalarArg);
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
TysForDecl.push_back(ScalarArg->getType());
continue;
}
Expand All @@ -15692,7 +15692,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
}
LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
OpVecs.push_back(OpVec);
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I, TTI))
TysForDecl.push_back(OpVec->getType());
}

Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,11 @@ VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
return It;
}

VPTransformState::VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
VPTransformState::VPTransformState(const TargetTransformInfo *TTI,
ElementCount VF, unsigned UF, LoopInfo *LI,
DominatorTree *DT, IRBuilderBase &Builder,
InnerLoopVectorizer *ILV, VPlan *Plan)
: VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
: TTI(TTI), VF(VF), CFG(DT), LI(LI), Builder(Builder), ILV(ILV), Plan(Plan),
LVer(nullptr), TypeAnalysis(Plan->getCanonicalIV()->getScalarType()) {}

Value *VPTransformState::get(VPValue *Def, const VPLane &Lane) {
Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,11 @@ class VPLane {
/// VPTransformState holds information passed down when "executing" a VPlan,
/// needed for generating the output IR.
struct VPTransformState {
VPTransformState(ElementCount VF, unsigned UF, LoopInfo *LI,
DominatorTree *DT, IRBuilderBase &Builder,
VPTransformState(const TargetTransformInfo *TTI, ElementCount VF, unsigned UF,
LoopInfo *LI, DominatorTree *DT, IRBuilderBase &Builder,
InnerLoopVectorizer *ILV, VPlan *Plan);
/// Target Transform Info.
const TargetTransformInfo *TTI;

/// The chosen Vectorization Factor of the loop being vectorized.
ElementCount VF;
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {

SmallVector<Type *, 2> TysForDecl;
// Add return type if intrinsic is overloaded on it.
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1))
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1, State.TTI))
TysForDecl.push_back(VectorType::get(getResultType(), State.VF));
SmallVector<Value *, 4> Args;
for (const auto &I : enumerate(operands())) {
Expand All @@ -952,7 +952,8 @@ void VPWidenIntrinsicRecipe::execute(VPTransformState &State) {
Arg = State.get(I.value(), VPLane(0));
else
Arg = State.get(I.value(), onlyFirstLaneUsed(I.value()));
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index(),
State.TTI))
TysForDecl.push_back(Arg->getType());
Args.push_back(Arg);
}
Expand Down

0 comments on commit 8663b87

Please sign in to comment.