Skip to content

Commit

Permalink
Vectorizer Update fmul instruction vectorized
Browse files Browse the repository at this point in the history
Vectorizer now can support vector emission of fmul instructions.
Implemented for triton flash attention kernel.
  • Loading branch information
esukhov authored and igcbot committed Dec 9, 2024
1 parent 7bbc6f6 commit 0b8394d
Show file tree
Hide file tree
Showing 8 changed files with 954 additions and 227 deletions.
35 changes: 35 additions & 0 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4330,6 +4330,7 @@ void EmitPass::EmitGenericPointersCmp(llvm::Instruction* inst,

void EmitPass::BinaryUnary(llvm::Instruction* inst, const SSource source[2], const DstModifier& modifier)
{

switch (inst->getOpcode())
{
case Instruction::FCmp:
Expand Down Expand Up @@ -4361,6 +4362,9 @@ void EmitPass::BinaryUnary(llvm::Instruction* inst, const SSource source[2], con
case Instruction::Mul:
Mul(source, modifier);
break;
case Instruction::FMul:
Mul(source, modifier);
break;
case Instruction::Call:
EmitAluIntrinsic(cast<CallInst>(inst), source, modifier);
break;
Expand Down Expand Up @@ -4572,6 +4576,15 @@ void EmitPass::Mul64(CVariable* dst, CVariable* src[2], SIMDMode simdMode, bool
m_encoder->Push();
}

static unsigned int getVectorSize(Instruction *I) {
IGCLLVM::FixedVectorType *VecType =
llvm::dyn_cast<IGCLLVM::FixedVectorType>(I->getType());
if (!VecType)
return 0;
unsigned int NumElements = VecType->getNumElements();
return NumElements;
}

void EmitPass::Mul(const SSource sources[2], const DstModifier& modifier)
{
CVariable* src[2];
Expand All @@ -4580,6 +4593,28 @@ void EmitPass::Mul(const SSource sources[2], const DstModifier& modifier)
src[i] = GetSrcVariable(sources[i]);
}

if (IGC_IS_FLAG_ENABLED(EnableVectorEmitter) && sources[0].value->getType()->isVectorTy() && sources[1].value->getType()->isVectorTy()) {

unsigned int VectorSize = 0;
if (llvm::isa<Instruction>(sources[0].value))
VectorSize = getVectorSize(llvm::cast<Instruction>(sources[0].value));

for (unsigned int i = 0; i < VectorSize; ++i) {
SetSourceModifiers(0, sources[0]);
SetSourceModifiers(1, sources[1]);

if (src[0]->IsUniform()) { m_encoder->SetSrcSubReg(0, i); }
else m_encoder->SetSrcSubVar(0, i);
if (src[1]->IsUniform()) { m_encoder->SetSrcSubReg(1, i); }
else m_encoder->SetSrcSubVar(1, i);

m_encoder->SetDstSubVar(i);
m_encoder->Mul(m_destination, src[0], src[1]);
m_encoder->Push();
}
return;
}

// Only i64 muls need special handling, otherwise go back to standard flow
VISA_Type srcType = src[0]->GetType();
if (srcType != ISA_TYPE_Q && srcType != ISA_TYPE_UQ)
Expand Down
Loading

0 comments on commit 0b8394d

Please sign in to comment.