Skip to content

Commit

Permalink
[Autobackout][FunctionalRegression]Revert of change: 375018b: Replaci…
Browse files Browse the repository at this point in the history
…ng usages of getNonOpaquePtrEltTy in AdaptorCommon - part 1

This change set is to prepare for removing dependencies on references to non-opaque pointers.
  • Loading branch information
sys-igc authored and igcbot committed Nov 14, 2024
1 parent 86b53a0 commit bbdf8df
Showing 1 changed file with 39 additions and 86 deletions.
125 changes: 39 additions & 86 deletions IGC/AdaptorCommon/LegalizeFunctionSignatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,27 +175,15 @@ inline bool isLegalSignatureType(const Module& M, Type* ty, bool isStackCall)
}

// Check if a struct pointer argument is promotable to pass-by-value
inline bool isPromotableStructType(const Module& M, const Argument* arg, bool isStackCall, bool isReturnValue = false)
inline bool isPromotableStructType(const Module& M, const Type* ty, bool isStackCall, bool isReturnValue = false)
{
if (IGC_IS_FLAG_DISABLED(EnableByValStructArgPromotion))
return false;

const unsigned int maxSize = isStackCall ? MAX_STRUCT_SIZE_IN_BITS : MAX_SUBROUTINE_STRUCT_SIZE_IN_BITS;
llvm::Type* structType = nullptr;
if (arg->getType()->isPointerTy())
if (ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy())
{
if (arg->hasStructRetAttr() && arg->getParamStructRetType()->isStructTy())
{
structType = arg->getParamStructRetType();
}
else if (arg->hasByValAttr() && arg->getParamByValType()->isStructTy())
{
structType = arg->getParamByValType();
}
}
if (structType)
{
return isLegalStructType(M, structType, maxSize);
return isLegalStructType(M, IGCLLVM::getNonOpaquePtrEltTy(ty), maxSize);
}
return false;
}
Expand All @@ -205,33 +193,23 @@ inline bool FunctionHasPromotableSRetArg(const Module& M, const Function* F)
{
if (F->getReturnType()->isVoidTy() &&
!F->arg_empty() &&
isPromotableStructType(M, F->arg_begin(), F->hasFnAttribute("visaStackCall"), true))
F->arg_begin()->hasStructRetAttr() &&
isPromotableStructType(M, F->arg_begin()->getType(), F->hasFnAttribute("visaStackCall"), true))
{
return true;
}
return false;
}

// Promotes struct pointer to struct type
inline StructType* PromotedStructValueType(const Module& M, const Argument* arg)
inline Type* PromotedStructValueType(const Module& M, const Type* ty)
{
if (arg->getType()->isPointerTy())
{
if (arg->hasStructRetAttr() && arg->getParamStructRetType()->isStructTy())
{
return cast<StructType>(arg->getParamStructRetType());
}
else if (arg->hasByValAttr() && arg->getParamByValType()->isStructTy())
{
return cast<StructType>(arg->getParamByValType());
}
}
IGC_ASSERT_MESSAGE(0, "Not implemented case");
return nullptr;
IGC_ASSERT(ty->isPointerTy() && IGCLLVM::getNonOpaquePtrEltTy(ty)->isStructTy());
return cast<StructType>(IGCLLVM::getNonOpaquePtrEltTy(ty));
}

// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaInst* strPtr)
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Value* strPtr)
{
IGC_ASSERT(strPtr->getType()->isPointerTy());
IGC_ASSERT(strVal->getType()->isStructTy());
Expand All @@ -240,45 +218,12 @@ inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, AllocaIn
for (unsigned i = 0; i < sTy->getNumElements(); i++)
{
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getAllocatedType(), strPtr, indices);
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
Value* element = builder.CreateExtractValue(strVal, i);
builder.CreateStore(element, elementPtr);
}
}

// BE does not handle struct load/store, so instead store each element of the struct value to the GEP of the struct pointer
inline void StoreToStruct(IGCLLVM::IRBuilder<>& builder, Value* strVal, Argument* strPtr)
{
IGC_ASSERT(strPtr->getType()->isPointerTy());
IGC_ASSERT(strVal->getType()->isStructTy());
if (strPtr->hasStructRetAttr() && strPtr->getParamStructRetType()->isStructTy())
{
StructType* sTy = cast<StructType>(strVal->getType());
for (unsigned i = 0; i < sTy->getNumElements(); i++)
{
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getParamStructRetType(), strPtr, indices);
Value* element = builder.CreateExtractValue(strVal, i);
builder.CreateStore(element, elementPtr);
}
}
else if (strPtr->hasByValAttr() && strPtr->getParamByValType()->isStructTy())
{
StructType* sTy = cast<StructType>(strVal->getType());
for (unsigned i = 0; i < sTy->getNumElements(); i++)
{
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
Value* elementPtr = builder.CreateInBoundsGEP(strPtr->getParamByValType(), strPtr, indices);
Value* element = builder.CreateExtractValue(strVal, i);
builder.CreateStore(element, elementPtr);
}
}
else
{
IGC_ASSERT_MESSAGE(0, "Unsupported case: no information about the pointee type");
}
}

// BE does not handle struct load/store, so instead load each element from the GEP struct pointer and insert it into the struct value
inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type* ty)
{
Expand All @@ -290,7 +235,7 @@ inline Value* LoadFromStruct(IGCLLVM::IRBuilder<>& builder, Value* strPtr, Type*
for (unsigned i = 0; i < sTy->getNumElements(); i++)
{
Value* indices[] = { builder.getInt32(0), builder.getInt32(i) };
Value* elementPtr = builder.CreateInBoundsGEP(ty, strPtr, indices);
Value* elementPtr = builder.CreateInBoundsGEP(strPtr, indices);
Value* element = builder.CreateLoad(sTy->getElementType(i), elementPtr);
strVal = builder.CreateInsertValue(strVal, element, i);
}
Expand Down Expand Up @@ -363,10 +308,10 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
argTypes.push_back(LegalizedIntVectorType(M, ai->getType()));
}
else if (ai->hasByValAttr() &&
isPromotableStructType(M, ai, isStackCall))
isPromotableStructType(M, ai->getType(), isStackCall))
{
fixArgType = true;
argTypes.push_back(PromotedStructValueType(M, ai));
argTypes.push_back(PromotedStructValueType(M, ai->getType()));
}
else if (!isLegalSignatureType(M, ai->getType(), isStackCall))
{
Expand All @@ -384,7 +329,7 @@ void LegalizeFunctionSignatures::FixFunctionSignatures(Module& M)
// Clone function with new signature
Type* returnType =
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(M.getContext()) :
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()) :
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, pFunc->arg_begin()->getType()) :
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, pFunc->getReturnType()) :
pFunc->getReturnType();
FunctionType* signature = FunctionType::get(returnType, argTypes, false);
Expand Down Expand Up @@ -448,7 +393,7 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
if (OldArgIt == pFunc->arg_begin() && retTypeOption == ReturnOpt::RETURN_STRUCT)
{
// Create a temp alloca to map the old argument. This will be removed later by SROA.
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt);
tempAllocaForSRetPointerTy = PromotedStructValueType(M, OldArgIt->getType());
tempAllocaForSRetPointer = builder.CreateAlloca(tempAllocaForSRetPointerTy);
tempAllocaForSRetPointer = builder.CreateAddrSpaceCast(tempAllocaForSRetPointer, OldArgIt->getType());
VMap[&*OldArgIt] = tempAllocaForSRetPointer;
Expand All @@ -463,25 +408,24 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
VMap[&*OldArgIt] = trunc;
}
else if (OldArgIt->hasByValAttr() &&
isPromotableStructType(M, OldArgIt, isStackCall))
isPromotableStructType(M, OldArgIt->getType(), isStackCall))
{
AllocaInst* newArgPtr = builder.CreateAlloca(OldArgIt->getParamByValType());
// remove "byval" attrib since it is now pass-by-value
NewArgIt->removeAttr(llvm::Attribute::ByVal);
Value* newArgPtr = builder.CreateAlloca(NewArgIt->getType());
StoreToStruct(builder, &*NewArgIt, newArgPtr);
// cast back to original addrspace
IGC_ASSERT(OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GENERIC ||
OldArgIt->getType()->getPointerAddressSpace() == ADDRESS_SPACE_PRIVATE);
llvm::Value* castedNewArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
VMap[&*OldArgIt] = castedNewArgPtr;
newArgPtr = builder.CreateAddrSpaceCast(newArgPtr, OldArgIt->getType());
VMap[&*OldArgIt] = newArgPtr;
}
else if (!isLegalSignatureType(M, OldArgIt->getType(), isStackCall))
{
// Load from pointer arg
Value* load = builder.CreateLoad(OldArgIt->getType(), &*NewArgIt);
Value* load = builder.CreateLoad(&*NewArgIt);
VMap[&*OldArgIt] = load;
llvm::Attribute byValAttr = llvm::Attribute::getWithByValType(M.getContext(), OldArgIt->getType());
NewArgIt->addAttr(byValAttr);
ArgByVal.push_back(&*NewArgIt);
}
else
{
Expand All @@ -500,13 +444,21 @@ void LegalizeFunctionSignatures::FixFunctionBody(Module& M)
builder.CreateBr(ClonedEntryBB);
MergeBlockIntoPredecessor(ClonedEntryBB);

// Loop through new args and add 'byval' attributes
for (auto arg : ArgByVal)
{
arg->addAttr(llvm::Attribute::getWithByValType(M.getContext(),
IGCLLVM::getNonOpaquePtrEltTy(arg->getType())));
}

// Now fix the return values
if (retTypeOption == ReturnOpt::RETURN_BY_REF)
{
// Add the 'noalias' and 'sret' attribute to arg0
auto retArg = pNewFunc->arg_begin();
retArg->addAttr(llvm::Attribute::NoAlias);
retArg->addAttr(llvm::Attribute::getWithStructRetType(M.getContext(), pFunc->getReturnType()));
retArg->addAttr(llvm::Attribute::getWithStructRetType(
M.getContext(), IGCLLVM::getNonOpaquePtrEltTy(retArg->getType())));

// Loop through all return instructions and store the old return value into the arg0 pointer
const auto ptrSize = DL.getPointerSize();
Expand Down Expand Up @@ -625,7 +577,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
if (callInst->getType()->isVoidTy() &&
IGCLLVM::getNumArgOperands(callInst) > 0 &&
callInst->paramHasAttr(0, llvm::Attribute::StructRet) &&
isPromotableStructType(M, callInst->getCalledFunction()->getArg(0), isStackCall, true /* retval */))
isPromotableStructType(M, callInst->getArgOperand(0)->getType(), isStackCall, true /* retval */))
{
opNum++; // Skip the first call operand
retTypeOption = ReturnOpt::RETURN_STRUCT;
Expand All @@ -651,22 +603,23 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
// Check call operands if it needs to be replaced
for (; opNum < IGCLLVM::getNumArgOperands(callInst); opNum++)
{
Argument* arg = IGCLLVM::getArg(*calledFunc, opNum);
Value* arg = callInst->getArgOperand(opNum);
if (!isLegalIntVectorType(M, arg->getType()))
{
// extend the illegal int to a legal type
IGCLLVM::IRBuilder<> builder(callInst);
Value* extend = builder.CreateZExt(callInst->getOperand(opNum), LegalizedIntVectorType(M, arg->getType()));
Value* extend = builder.CreateZExt(arg, LegalizedIntVectorType(M, arg->getType()));
callArgs.push_back(extend);
ArgAttrVec.push_back(AttributeSet());
fixArgType = true;
}
else if (callInst->paramHasAttr(opNum, llvm::Attribute::ByVal) &&
isPromotableStructType(M, arg, isStackCall))
isPromotableStructType(M, arg->getType(), isStackCall))
{
// Map the new operand to the loaded value of the struct pointer
IGCLLVM::IRBuilder<> builder(callInst);
Value* newOp = LoadFromStruct(builder, callInst->getOperand(opNum), arg->getParamByValType());
Argument* callArg = IGCLLVM::getArg(*calledFunc, opNum);
Value* newOp = LoadFromStruct(builder, arg, callArg->getParamByValType());
callArgs.push_back(newOp);
ArgAttrVec.push_back(AttributeSet());
fixArgType = true;
Expand All @@ -676,7 +629,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
// Create and store operand as an alloca, then pass as argument
IGCLLVM::IRBuilder<> builder(callInst);
Value* allocaV = builder.CreateAlloca(arg->getType());
builder.CreateStore(callInst->getOperand(opNum), allocaV);
builder.CreateStore(arg, allocaV);
callArgs.push_back(allocaV);
auto byValAttr = llvm::Attribute::getWithByValType(M.getContext(), arg->getType());
auto argAttrs = AttributeSet::get(M.getContext(), { byValAttr });
Expand Down Expand Up @@ -706,7 +659,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
}
Type* retType =
retTypeOption == ReturnOpt::RETURN_BY_REF ? Type::getVoidTy(callInst->getContext()) :
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getFunction()->getArg(0)) :
retTypeOption == ReturnOpt::RETURN_STRUCT ? PromotedStructValueType(M, callInst->getArgOperand(0)->getType()) :
retTypeOption == ReturnOpt::RETURN_LEGAL_INT ? LegalizedIntVectorType(M, callInst->getType()) :
callInst->getType();
newFnTy = FunctionType::get(retType, argTypes, false);
Expand Down Expand Up @@ -737,7 +690,7 @@ void LegalizeFunctionSignatures::FixCallInstruction(Module& M, CallInst* callIns
else if (retTypeOption == ReturnOpt::RETURN_STRUCT)
{
// Store the struct value into the orginal pointer operand
StoreToStruct(builder, newCallInst, callInst->getCalledFunction()->getArg(0));
StoreToStruct(builder, newCallInst, callInst->getArgOperand(0));
}
else if (retTypeOption == ReturnOpt::RETURN_LEGAL_INT)
{
Expand Down

0 comments on commit bbdf8df

Please sign in to comment.