WaveAllJointReduction Optimization
Merge multiple consecutive WaveAll operations into a joint reduction tree
bowenxue-intel authored and igcbot committed Dec 3, 2024
1 parent 1037a76 commit be108bd
Showing 13 changed files with 888 additions and 7 deletions.
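
The commit message above summarizes the idea: instead of lowering each WaveAll reduction into its own narrow tree, back-to-back reductions are merged so EmitVISAPass can build one joint tree (the new vector-operand path in emitWaveAll below). A minimal host-side sketch of what the merged form computes — assuming a subgroup width of 16, an add reduction, and three merged reductions; none of this code is from the patch:

// Host-side model only; kSimd, the add operator, and the three-way merge are
// illustrative assumptions, not values taken from the patch.
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kSimd = 16;

// Before the pass: each WaveAll lowers to its own reduction tree over the subgroup.
int waveAllModel(const std::array<int, kSimd>& lanes) {
    int acc = 0;
    for (int v : lanes) acc += v;
    return acc;
}

// After the pass: one vector-typed WaveAll carries all three sources, so the backend
// can reduce the 3 * kSimd elements in a single joint tree (see emitReductionTrees).
std::array<int, 3> waveAllJointModel(const std::array<int, 3 * kSimd>& src) {
    std::array<int, 3> out{};
    for (std::size_t g = 0; g < 3; ++g)
        for (std::size_t i = 0; i < kSimd; ++i)
            out[g] += src[g * kSimd + i];
    return out;
}

int main() {
    std::array<int, 3 * kSimd> src{};
    for (std::size_t i = 0; i < src.size(); ++i) src[i] = static_cast<int>(i);
    auto joint = waveAllJointModel(src);
    for (std::size_t g = 0; g < 3; ++g) {
        std::array<int, kSimd> group{};
        for (std::size_t i = 0; i < kSimd; ++i) group[i] = src[g * kSimd + i];
        assert(joint[g] == waveAllModel(group));   // merged and separate paths agree
    }
    return 0;
}

The payoff is that each ALU instruction in the joint tree can be packed with lanes from several reductions at once, which is what the region-strided operations in emitReductionTree below arrange.
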
186 changes: 179 additions & 7 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.cpp
@@ -13565,7 +13565,6 @@ CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode
// Reduction all expand helper: dst_lane{0..(simd-1)} = src_lane{0} OP src_lane{1}
void EmitPass::ReductionExpandHelper(e_opcode op, VISA_Type type, CVariable* src, CVariable* dst)
{
const bool is64bitType = ScanReduceIs64BitType(type);
const bool isInt64Mul = ScanReduceIsInt64Mul(op, type);
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type);

@@ -13878,6 +13877,110 @@ void EmitPass::ReductionClusteredExpandHelper(e_opcode op, VISA_Type type, SIMDM
}
}

void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, CVariable* dst )
{
const bool isInt64Mul = ScanReduceIsInt64Mul( op, type );
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded( op, type );

uint16_t srcElementCount = src->GetNumberElement(); // total elements in reduction tree
uint16_t reductionElementCount = srcElementCount / dst->GetNumberElement(); // number of elements participating per reduction
// Build reduction tree layers
while( srcElementCount > dst->GetNumberElement() )
{
// Each layer operation merges multiple separate reduction intermediary steps
// Calculate max lanes per operation and number of merged reduction operations for current layer
SIMDMode maxSimdMode = ( m_currShader->m_dispatchSize == SIMDMode::SIMD32 && m_currShader->m_numberInstance > 1 ) ? SIMDMode::SIMD16 : m_currShader->m_dispatchSize;
SIMDMode layerMaxSimdMode = lanesToSIMDMode( min( numLanes( maxSimdMode ), (uint16_t)( srcElementCount >> 1 ) ) );
uint16_t layerMaxSimdLanes = numLanes( layerMaxSimdMode );
uint16_t src1Offset = reductionElementCount >> 1;
unsigned int numIterations = srcElementCount / ( 2 * layerMaxSimdLanes ); // number of reduction operations for current layer
for( unsigned int i = 0; i < numIterations; i++ )
{
// Get alias for src0, src1, and dst based on offsets and SIMD size
auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );
auto* layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes );

if( !int64EmulationNeeded )
{
m_encoder->SetNoMask();
m_encoder->SetSimdSize( layerMaxSimdMode );
// Set up correct vertical stride and width
m_encoder->SetSrcRegion( 0, reductionElementCount, ( reductionElementCount >> 1 ), 1 );
m_encoder->SetSrcRegion( 1, reductionElementCount, ( reductionElementCount >> 1 ), 1 );
m_encoder->GenericAlu( op, layerDst, layerSrc0, layerSrc1 );
m_encoder->Push();
}
else
{
if( isInt64Mul )
{
CVariable* tempMulSrc[ 2 ] = { layerSrc0, layerSrc1 };
Mul64( layerDst, tempMulSrc, layerMaxSimdMode, true /*noMask*/ );
}
else
{
IGC_ASSERT_MESSAGE( 0, "Unsupported" );
}
}
}

// Layer complete, total number of elements and number of elements participating per reduction halved
srcElementCount >>= 1;
reductionElementCount >>= 1;
}

// copy fully reduced elements from src to dst
auto* finalLayerDst = m_currShader->GetNewAlias( src, type, 0, dst->GetNumberElement() );
m_encoder->SetNoMask();
m_encoder->SetSimdSize( lanesToSIMDMode( dst->GetNumberElement() ) );
m_encoder->Copy( dst, finalLayerDst );
m_encoder->Push();
}

// Recursive function that emits one or more joint reduction trees based on the joint output width
void EmitPass::emitReductionTrees( e_opcode op, VISA_Type type, SIMDMode simdMode, CVariable* src, CVariable* dst, unsigned int startIdx, unsigned int endIdx )
{
unsigned int numGroups = endIdx - startIdx + 1;
// lanes for final joint reduction
uint16_t simdLanes = numLanes( simdMode );
if( numGroups >= simdLanes )
{
// Do full tree reduction
unsigned int reductionElements = src->GetNumberElement() / dst->GetNumberElement();
unsigned int groupReductionElementCount = reductionElements * simdLanes;
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount );
CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes);
emitReductionTree( op, type, srcAlias, dstAlias );
// Start new recursive tree if any elements are left
if ( numGroups > simdLanes )
{
emitReductionTrees( op, type, simdMode, src, dst, startIdx + simdLanes, endIdx );
}
}
else
{
// Overshoot, try lower SIMD for the final reduction op
// TODO: Instead of trying lower SIMD, could generate a simdLanes-wide final join instruction and pass in identity/0/don't-care values for the unused joins
// However, this will require a change to WaveAllJointReduction to generate intrinsic calls with a fixed vector width, to ensure the vector source variable used is generated with the proper bounds,
// or logic to copy the vector source variable to a simdLane * simdLane sized variable, along with logic to generate only the necessary operation on that variable
switch( simdMode )
{
case SIMDMode::SIMD32:
return emitReductionTrees( op, type, SIMDMode::SIMD16, src, dst, startIdx, endIdx );
case SIMDMode::SIMD16:
return emitReductionTrees( op, type, SIMDMode::SIMD8, src, dst, startIdx, endIdx );
case SIMDMode::SIMD8:
return emitReductionTrees( op, type, SIMDMode::SIMD4, src, dst, startIdx, endIdx );
case SIMDMode::SIMD4:
return emitReductionTrees( op, type, SIMDMode::SIMD2, src, dst, startIdx, endIdx );
case SIMDMode::SIMD2:
default:
return emitReductionTrees( op, type, SIMDMode::SIMD1, src, dst, startIdx, endIdx );
}
}
}

// do reduction and accumulate all the active channels, return a uniform
void EmitPass::emitReductionAll(
e_opcode op, uint64_t identityValue, VISA_Type type, bool negate, CVariable* src, CVariable* dst)
@@ -13893,8 +13996,6 @@ void EmitPass::emitReductionAll(
}
else
{
const SIMDMode simd = SIMDMode::SIMD16;

CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */,
src, nullptr /* dst */);

@@ -21891,15 +21992,86 @@ void EmitPass::emitWaveAll(llvm::GenIntrinsicInst* inst)
{
ForceDMask();
}
m_encoder->SetSubSpanDestination( false );
CVariable* src = GetSymbol(inst->getOperand(0));
CVariable* dst = m_destination;
const WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
VISA_Type type;
e_opcode opCode;
uint64_t identity = 0;
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
CVariable* dst = m_destination;
m_encoder->SetSubSpanDestination(false);
emitReductionAll(opCode, identity, type, false, src, dst);
if( inst->getOperand( 0 )->getType()->isVectorTy() )
{
// Joint reduction optimization: merging multiple consecutive independent wave ops lets us construct a wider reduction tree
GetReductionOp( op, cast<VectorType>( inst->getOperand( 0 )->getType() )->getElementType(), identity, opCode, type );

if( m_currShader->m_dispatchSize == SIMDMode::SIMD32 && m_currShader->m_numberInstance > 1 )
{
// Dual SIMD16 mode, use 1 SIMD16 inst per reduction for first layer to reduce 32 elements down to 16
CVariable* reduceSrc = m_currShader->GetNewVariable( src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrc_" ), src->getName().getCString() ) );
CVariable* reduceSrcSecondHalf = m_currShader->GetNewVariable( src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrcSecondHalf_" ), src->getName().getCString() ) );

const bool isInt64Mul = ScanReduceIsInt64Mul( opCode, type );
const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded( opCode, type );

// Explicitly generate the first layer (technically the 0th layer, since no operations are joint yet; we are still operating within a single reduction op)
for( uint16_t i = 0; i < dst->GetNumberElement(); i++ )
{
// Prepare reduceSrc
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias );

// Prepare reduceSrcSecondHalf
CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
ScanReducePrepareSrc( type, identity, false, true, srcSecondHalfAlias, reduceSrcSecondHalfAlias );

// Emit correct operations
if( !int64EmulationNeeded )
{
m_encoder->SetNoMask();
m_encoder->SetSimdSize( SIMDMode::SIMD16 );
m_encoder->GenericAlu( opCode, reduceSrcAlias, reduceSrcAlias, reduceSrcSecondHalfAlias );
m_encoder->Push();
}
else
{
if( isInt64Mul )
{
CVariable* tmpMulSrc[ 2 ] = { reduceSrcAlias, reduceSrcSecondHalfAlias };
Mul64( reduceSrcAlias, tmpMulSrc, SIMDMode::SIMD16, true );
}
else
{
IGC_ASSERT_MESSAGE( 0, "Unsupported" );
}
}
}

// Now that 32 elements per reduction have been reduced to 16 in layer 0, can proceed with regular reduction tree implementation using SIMD16
emitReductionTrees( opCode, type, SIMDMode::SIMD16, reduceSrc, dst, 0, dst->GetNumberElement() - 1 );
}
else
{
CVariable* reduceSrc = m_currShader->GetNewVariable( src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrc_" ), src->getName().getCString() ) );
// Prepare reduceSrc for all elements
for( int i = 0; i < dst->GetNumberElement(); i++ )
{
CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) );
ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias );
}

emitReductionTrees( opCode, type, m_currShader->m_dispatchSize, reduceSrc, dst, 0, dst->GetNumberElement() - 1 );
}
}
else
{
// Single WaveAll, emit base reduction tree
GetReductionOp( op, inst->getOperand( 0 )->getType(), identity, opCode, type );
emitReductionAll( opCode, identity, type, false, src, dst );
}

if (disableHelperLanes)
{
ResetVMask();
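
For readers following the emitter changes above: the arithmetic that emitReductionTree encodes with region-strided ALU instructions can be modeled as an in-place, layer-by-layer fold in which every layer halves each group's element count and packs the partial results densely at the front of the buffer. A host-side sketch under assumptions (OP = add, int32 data, power-of-two group sizes as the SIMD layering implies); this is illustrative code, not the emitter:

#include <cstddef>
#include <cstdint>
#include <vector>

// src holds dstElements groups laid out back to back, each with src.size()/dstElements
// entries; the return value is one reduced element per group, mirroring the final copy
// from finalLayerDst into dst in emitReductionTree.
std::vector<int32_t> reductionTreeModel(std::vector<int32_t> src, std::size_t dstElements) {
    std::size_t perGroup = src.size() / dstElements;      // reductionElementCount
    while (perGroup > 1) {
        std::size_t half = perGroup >> 1;                 // src1Offset
        // One layer: fold the upper half of every group onto its lower half and pack the
        // results densely, so the next layer touches only half as many elements. In the
        // emitter each layer is split into SIMD-width-bounded instructions (layerMaxSimdLanes).
        for (std::size_t g = 0; g < dstElements; ++g)
            for (std::size_t i = 0; i < half; ++i)
                src[g * half + i] = src[g * perGroup + i] + src[g * perGroup + half + i];
        perGroup = half;
    }
    src.resize(dstElements);                              // one fully reduced value per group
    return src;
}

emitReductionTrees then walks the merged outputs in chunks of at most the dispatch SIMD width, invoking this tree per chunk and recursing with a narrower SIMD mode once fewer groups remain than lanes; in dual-SIMD16 dispatch, emitWaveAll first folds the second 16-lane instance onto the first with one SIMD16 operation per reduction before running the tree at SIMD16.
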
13 changes: 13 additions & 0 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.hpp
@@ -320,6 +320,19 @@ class EmitPass : public llvm::FunctionPass
bool negate,
CVariable* src,
CVariable* dst);
void emitReductionTree(
e_opcode op,
VISA_Type type,
CVariable* src,
CVariable* dst );
void emitReductionTrees(
e_opcode op,
VISA_Type type,
SIMDMode simdMode,
CVariable* src,
CVariable* dst,
unsigned int startIdx,
unsigned int endIdx );
void emitReductionClustered(
const e_opcode op,
const uint64_t identityValue,
6 changes: 6 additions & 0 deletions IGC/Compiler/CISACodeGen/ShaderCodeGen.cpp
@@ -102,6 +102,7 @@ SPDX-License-Identifier: MIT
#include "Compiler/Optimizer/BarrierControlFlowOptimization.hpp"
#include "Compiler/Optimizer/RuntimeValueVectorExtractPass.h"
#include "Compiler/Optimizer/WaveShuffleIndexSinking.hpp"
#include "Compiler/Optimizer/WaveAllJointReduction.hpp"
#include "Compiler/MetaDataApi/PurgeMetaDataUtils.hpp"
#include "Compiler/HandleLoadStoreInstructions.hpp"
#include "Compiler/CustomSafeOptPass.hpp"
@@ -1869,6 +1870,11 @@ void OptimizeIR(CodeGenContext* const pContext)

mpm.add(llvm::createDeadCodeEliminationPass());

if( IGC_IS_FLAG_ENABLED(EnableWaveAllJointReduction) )
{
mpm.add( createWaveAllJointReduction() );
}

if (IGC_IS_FLAG_ENABLED(EnableIntDivRemCombine)) {
// simplify rem if the quotient is available
//
1 change: 1 addition & 0 deletions IGC/Compiler/InitializePasses.h
@@ -203,6 +203,7 @@ void initializeVectorBitCastOptPass(llvm::PassRegistry&);
void initializeVectorPreProcessPass(llvm::PassRegistry&);
void initializeVectorProcessPass(llvm::PassRegistry&);
void initializeVerificationPassPass(llvm::PassRegistry&);
void initializeWaveAllJointReductionPass(llvm::PassRegistry&);
void initializeWGFuncResolutionPass(llvm::PassRegistry&);
void initializeWIAnalysisPass(llvm::PassRegistry&);
void initializeWIFuncResolutionPass(llvm::PassRegistry&);
2 changes: 2 additions & 0 deletions IGC/Compiler/Optimizer/CMakeLists.txt
@@ -34,6 +34,7 @@ set(IGC_BUILD__SRC__Optimizer
"${CMAKE_CURRENT_SOURCE_DIR}/RuntimeValueVectorExtractPass.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/BarrierControlFlowOptimization.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/WaveShuffleIndexSinking.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/WaveAllJointReduction.cpp"
)

set(IGC_BUILD__SRC__Compiler_Optimizer
@@ -61,6 +62,7 @@ set(IGC_BUILD__HDR__Optimizer
"${CMAKE_CURRENT_SOURCE_DIR}/RuntimeValueVectorExtractPass.h"
"${CMAKE_CURRENT_SOURCE_DIR}/BarrierControlFlowOptimization.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/WaveShuffleIndexSinking.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/WaveAllJointReduction.hpp"
)

set(IGC_BUILD__HDR__Optimizer