diff --git a/IGC/Compiler/CISACodeGen/EmitVISAPass.cpp b/IGC/Compiler/CISACodeGen/EmitVISAPass.cpp index 46a40e2e08b5..1d9d78b12537 100644 --- a/IGC/Compiler/CISACodeGen/EmitVISAPass.cpp +++ b/IGC/Compiler/CISACodeGen/EmitVISAPass.cpp @@ -13565,7 +13565,6 @@ CVariable* EmitPass::ReductionReduceHelper(e_opcode op, VISA_Type type, SIMDMode // Reduction all expand helper: dst_lane{0..(simd-1)} = src_lane{0} OP src_lane{1} void EmitPass::ReductionExpandHelper(e_opcode op, VISA_Type type, CVariable* src, CVariable* dst) { - const bool is64bitType = ScanReduceIs64BitType(type); const bool isInt64Mul = ScanReduceIsInt64Mul(op, type); const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded(op, type); @@ -13878,6 +13877,110 @@ void EmitPass::ReductionClusteredExpandHelper(e_opcode op, VISA_Type type, SIMDM } } +void EmitPass::emitReductionTree( e_opcode op, VISA_Type type, CVariable* src, CVariable* dst ) +{ + const bool isInt64Mul = ScanReduceIsInt64Mul( op, type ); + const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded( op, type ); + + uint16_t srcElementCount = src->GetNumberElement(); // total elements in reduction tree + uint16_t reductionElementCount = srcElementCount / dst->GetNumberElement(); // number of elements participating per reduction + // Build reduction tree layers + while( srcElementCount > dst->GetNumberElement() ) + { + // Each layer operation merges multiple separate reduction intermediary steps + // Calculate max lanes per operation and number of merged reduction operations for current layer + SIMDMode maxSimdMode = ( m_currShader->m_dispatchSize == SIMDMode::SIMD32 && m_currShader->m_numberInstance > 1 ) ? 
SIMDMode::SIMD16 : m_currShader->m_dispatchSize; + SIMDMode layerMaxSimdMode = lanesToSIMDMode( min( numLanes( maxSimdMode ), (uint16_t)( srcElementCount >> 1 ) ) ); + uint16_t layerMaxSimdLanes = numLanes( layerMaxSimdMode ); + uint16_t src1Offset = reductionElementCount >> 1; + unsigned int numIterations = srcElementCount / ( 2 * layerMaxSimdLanes ); // number of reduction operations for current layer + for( unsigned int i = 0; i < numIterations; i++ ) + { + // Get alias for src0, src1, and dst based on offsets and SIMD size + auto* layerSrc0 = m_currShader->GetNewAlias( src, type, i * 2 * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes ); + auto* layerSrc1 = m_currShader->GetNewAlias( src, type, ( i * 2 * layerMaxSimdLanes + src1Offset ) * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes ); + auto* layerDst = m_currShader->GetNewAlias( src, type, i * layerMaxSimdLanes * m_encoder->GetCISADataTypeSize( type ), layerMaxSimdLanes ); + + if( !int64EmulationNeeded ) + { + m_encoder->SetNoMask(); + m_encoder->SetSimdSize( layerMaxSimdMode ); + // Set up correct vertical stride and width + m_encoder->SetSrcRegion( 0, reductionElementCount, ( reductionElementCount >> 1 ), 1 ); + m_encoder->SetSrcRegion( 1, reductionElementCount, ( reductionElementCount >> 1 ), 1 ); + m_encoder->GenericAlu( op, layerDst, layerSrc0, layerSrc1 ); + m_encoder->Push(); + } + else + { + if( isInt64Mul ) + { + CVariable* tempMulSrc[ 2 ] = { layerSrc0, layerSrc1 }; + Mul64( layerDst, tempMulSrc, layerMaxSimdMode, true /*noMask*/ ); + } + else + { + IGC_ASSERT_MESSAGE( 0, "Unsupported" ); + } + } + } + + // Layer complete, total number of elements and number of elements participating per reduction halved + srcElementCount >>= 1; + reductionElementCount >>= 1; + } + + // copy fully reduced elements from src to dst + auto* finalLayerDst = m_currShader->GetNewAlias( src, type, 0, dst->GetNumberElement() ); + m_encoder->SetNoMask(); + m_encoder->SetSimdSize( 
lanesToSIMDMode( dst->GetNumberElement() ) ); + m_encoder->Copy( dst, finalLayerDst ); + m_encoder->Push(); +} + +// Recursive function that emits one or more joint reduction trees based on the joint output width +void EmitPass::emitReductionTrees( e_opcode op, VISA_Type type, SIMDMode simdMode, CVariable* src, CVariable* dst, unsigned int startIdx, unsigned int endIdx ) +{ + unsigned int numGroups = endIdx - startIdx + 1; + // lanes for final joint reduction + uint16_t simdLanes = numLanes( simdMode ); + if( numGroups >= simdLanes ) + { + // Do full tree reduction + unsigned int reductionElements = src->GetNumberElement() / dst->GetNumberElement(); + unsigned int groupReductionElementCount = reductionElements * simdLanes; + CVariable* srcAlias = m_currShader->GetNewAlias( src, type, startIdx * reductionElements * m_encoder->GetCISADataTypeSize( type ), groupReductionElementCount ); + CVariable* dstAlias = m_currShader->GetNewAlias( dst, type, startIdx * m_encoder->GetCISADataTypeSize( type ), simdLanes); + emitReductionTree( op, type, srcAlias, dstAlias ); + // Start new recursive tree if any elements are left + if ( numGroups > simdLanes ) + { + emitReductionTrees( op, type, simdMode, src, dst, startIdx + simdLanes, endIdx ); + } + } + else + { + // Overshoot, try lower SIMD for the final reduction op + // TODO: Instead of trying lower SIMD, could generate simdLanes wide final join instruction, and pass in identity/0/don't care values for unused joins + // However, this will require a change to WaveAllJointReduction to generate intrinsic calls with fixed vector width to ensure the vector source variable used is generated with the proper bounds + // or logic to copy the vector source variable to a simdLane * simdLane sized variable along with logic to generate only the necessary operation on that variable + switch( simdMode ) + { + case SIMDMode::SIMD32: + return emitReductionTrees( op, type, SIMDMode::SIMD16, src, dst, startIdx, endIdx ); + case SIMDMode::SIMD16: 
+ return emitReductionTrees( op, type, SIMDMode::SIMD8, src, dst, startIdx, endIdx ); + case SIMDMode::SIMD8: + return emitReductionTrees( op, type, SIMDMode::SIMD4, src, dst, startIdx, endIdx ); + case SIMDMode::SIMD4: + return emitReductionTrees( op, type, SIMDMode::SIMD2, src, dst, startIdx, endIdx ); + case SIMDMode::SIMD2: + default: + return emitReductionTrees( op, type, SIMDMode::SIMD1, src, dst, startIdx, endIdx ); + } + } +} + // do reduction and accumulate all the activate channels, return a uniform void EmitPass::emitReductionAll( e_opcode op, uint64_t identityValue, VISA_Type type, bool negate, CVariable* src, CVariable* dst) @@ -13893,8 +13996,6 @@ void EmitPass::emitReductionAll( } else { - const SIMDMode simd = SIMDMode::SIMD16; - CVariable* srcH2 = ScanReducePrepareSrc(type, identityValue, negate, true /* secondHalf */, src, nullptr /* dst */); @@ -21891,15 +21992,86 @@ void EmitPass::emitWaveAll(llvm::GenIntrinsicInst* inst) { ForceDMask(); } + m_encoder->SetSubSpanDestination( false ); CVariable* src = GetSymbol(inst->getOperand(0)); + CVariable* dst = m_destination; const WaveOps op = static_cast(cast(inst->getOperand(1))->getZExtValue()); VISA_Type type; e_opcode opCode; uint64_t identity = 0; - GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type); - CVariable* dst = m_destination; - m_encoder->SetSubSpanDestination(false); - emitReductionAll(opCode, identity, type, false, src, dst); + if( inst->getOperand( 0 )->getType()->isVectorTy() ) + { + // Joint Reduction optimization from multiple consecutive independent wave ops, can construct wider reduction tree + GetReductionOp( op, cast( inst->getOperand( 0 )->getType() )->getElementType(), identity, opCode, type ); + + if( m_currShader->m_dispatchSize == SIMDMode::SIMD32 && m_currShader->m_numberInstance > 1 ) + { + // Dual SIMD16 mode, use 1 SIMD16 inst per reduction for first layer to reduce 32 elements down to 16 + CVariable* reduceSrc = m_currShader->GetNewVariable( 
src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrc_" ), src->getName().getCString() ) ); + CVariable* reduceSrcSecondHalf = m_currShader->GetNewVariable( src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrcSecondHalf_" ), src->getName().getCString() ) ); + + const bool isInt64Mul = ScanReduceIsInt64Mul( opCode, type ); + const bool int64EmulationNeeded = ScanReduceIsInt64EmulationNeeded( opCode, type ); + + // Explicitly generate First layer (Technically 0th layer since no operations are joint yet, we are still operating within a single reduction op) + for( uint16_t i = 0; i < dst->GetNumberElement(); i++ ) + { + // Prepare reduceSrc + CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias ); + + // Prepare reduceSrcSecondHalf + CVariable* srcSecondHalfAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + CVariable* reduceSrcSecondHalfAlias = m_currShader->GetNewAlias( reduceSrcSecondHalf, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + ScanReducePrepareSrc( type, identity, false, true, srcSecondHalfAlias, reduceSrcSecondHalfAlias ); + + // Emit correct operations + if( !int64EmulationNeeded ) + { + m_encoder->SetNoMask(); + m_encoder->SetSimdSize( SIMDMode::SIMD16 ); + m_encoder->GenericAlu( opCode, reduceSrcAlias, reduceSrcAlias, reduceSrcSecondHalfAlias ); + m_encoder->Push(); + } + else + { + if( isInt64Mul ) + { + CVariable* 
tmpMulSrc[ 2 ] = { reduceSrcAlias, reduceSrcSecondHalfAlias }; + Mul64( reduceSrcAlias, tmpMulSrc, SIMDMode::SIMD16, true ); + } + else + { + IGC_ASSERT_MESSAGE( 0, "Unsupported" ); + } + } + } + + // Now that 32 elements per reduction have been reduced to 16 in layer 0, can proceed with regular reduction tree implementation using SIMD16 + emitReductionTrees( opCode, type, SIMDMode::SIMD16, reduceSrc, dst, 0, dst->GetNumberElement() - 1 ); + } + else + { + CVariable* reduceSrc = m_currShader->GetNewVariable( src->GetNumberElement(), type, src->GetAlign(), CName( CName( "reduceSrc_" ), src->getName().getCString() ) ); + // Prepare reduceSrc for all elements + for( int i = 0; i < dst->GetNumberElement(); i++ ) + { + CVariable* srcAlias = m_currShader->GetNewAlias( src, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + CVariable* reduceSrcAlias = m_currShader->GetNewAlias( reduceSrc, type, i * numLanes( m_currShader->m_SIMDSize ) * m_encoder->GetCISADataTypeSize( type ), numLanes( m_currShader->m_SIMDSize ) ); + ScanReducePrepareSrc( type, identity, false, false, srcAlias, reduceSrcAlias ); + } + + emitReductionTrees( opCode, type, m_currShader->m_dispatchSize, reduceSrc, dst, 0, dst->GetNumberElement() - 1 ); + } + } + else + { + // Single WaveAll, emit base reduction tree + GetReductionOp( op, inst->getOperand( 0 )->getType(), identity, opCode, type ); + emitReductionAll( opCode, identity, type, false, src, dst ); + } + if (disableHelperLanes) { ResetVMask(); diff --git a/IGC/Compiler/CISACodeGen/EmitVISAPass.hpp b/IGC/Compiler/CISACodeGen/EmitVISAPass.hpp index dae96efc0c9a..aa4266efcfa3 100644 --- a/IGC/Compiler/CISACodeGen/EmitVISAPass.hpp +++ b/IGC/Compiler/CISACodeGen/EmitVISAPass.hpp @@ -320,6 +320,19 @@ class EmitPass : public llvm::FunctionPass bool negate, CVariable* src, CVariable* dst); + void emitReductionTree( + e_opcode op, + VISA_Type type, + CVariable* src, + CVariable* 
dst ); + void emitReductionTrees( + e_opcode op, + VISA_Type type, + SIMDMode simdMode, + CVariable* src, + CVariable* dst, + unsigned int startIdx, + unsigned int endIdx ); void emitReductionClustered( const e_opcode op, const uint64_t identityValue, diff --git a/IGC/Compiler/CISACodeGen/ShaderCodeGen.cpp b/IGC/Compiler/CISACodeGen/ShaderCodeGen.cpp index 455112f206b9..85e18b32471e 100644 --- a/IGC/Compiler/CISACodeGen/ShaderCodeGen.cpp +++ b/IGC/Compiler/CISACodeGen/ShaderCodeGen.cpp @@ -102,6 +102,7 @@ SPDX-License-Identifier: MIT #include "Compiler/Optimizer/BarrierControlFlowOptimization.hpp" #include "Compiler/Optimizer/RuntimeValueVectorExtractPass.h" #include "Compiler/Optimizer/WaveShuffleIndexSinking.hpp" +#include "Compiler/Optimizer/WaveAllJointReduction.hpp" #include "Compiler/MetaDataApi/PurgeMetaDataUtils.hpp" #include "Compiler/HandleLoadStoreInstructions.hpp" #include "Compiler/CustomSafeOptPass.hpp" @@ -1869,6 +1870,11 @@ void OptimizeIR(CodeGenContext* const pContext) mpm.add(llvm::createDeadCodeEliminationPass()); + if( IGC_IS_FLAG_ENABLED(EnableWaveAllJointReduction) ) + { + mpm.add( createWaveAllJointReduction() ); + } + if (IGC_IS_FLAG_ENABLED(EnableIntDivRemCombine)) { // simplify rem if the quotient is availble // diff --git a/IGC/Compiler/InitializePasses.h b/IGC/Compiler/InitializePasses.h index a67b07331776..bbd45cb207a3 100644 --- a/IGC/Compiler/InitializePasses.h +++ b/IGC/Compiler/InitializePasses.h @@ -203,6 +203,7 @@ void initializeVectorBitCastOptPass(llvm::PassRegistry&); void initializeVectorPreProcessPass(llvm::PassRegistry&); void initializeVectorProcessPass(llvm::PassRegistry&); void initializeVerificationPassPass(llvm::PassRegistry&); +void initializeWaveAllJointReductionPass(llvm::PassRegistry&); void initializeWGFuncResolutionPass(llvm::PassRegistry&); void initializeWIAnalysisPass(llvm::PassRegistry&); void initializeWIFuncResolutionPass(llvm::PassRegistry&); diff --git a/IGC/Compiler/Optimizer/CMakeLists.txt 
b/IGC/Compiler/Optimizer/CMakeLists.txt index 17e10d1ea5a3..5008c2236ace 100644 --- a/IGC/Compiler/Optimizer/CMakeLists.txt +++ b/IGC/Compiler/Optimizer/CMakeLists.txt @@ -34,6 +34,7 @@ set(IGC_BUILD__SRC__Optimizer "${CMAKE_CURRENT_SOURCE_DIR}/RuntimeValueVectorExtractPass.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/BarrierControlFlowOptimization.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/WaveShuffleIndexSinking.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/WaveAllJointReduction.cpp" ) set(IGC_BUILD__SRC__Compiler_Optimizer @@ -61,6 +62,7 @@ set(IGC_BUILD__HDR__Optimizer "${CMAKE_CURRENT_SOURCE_DIR}/RuntimeValueVectorExtractPass.h" "${CMAKE_CURRENT_SOURCE_DIR}/BarrierControlFlowOptimization.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/WaveShuffleIndexSinking.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/WaveAllJointReduction.hpp" ) set(IGC_BUILD__HDR__Optimizer diff --git a/IGC/Compiler/Optimizer/WaveAllJointReduction.cpp b/IGC/Compiler/Optimizer/WaveAllJointReduction.cpp new file mode 100644 index 000000000000..344308fc2af9 --- /dev/null +++ b/IGC/Compiler/Optimizer/WaveAllJointReduction.cpp @@ -0,0 +1,162 @@ +/*========================== begin_copyright_notice ============================ + +Copyright (C) 2024 Intel Corporation + +SPDX-License-Identifier: MIT + +============================= end_copyright_notice ===========================*/ + +#include +#include "WaveAllJointReduction.hpp" +#include "Compiler/IGCPassSupport.h" +#include "Compiler/InitializePasses.h" +#include +#include "common/LLVMWarningsPush.hpp" +#include +#include +#include +#include "common/LLVMWarningsPop.hpp" + +#define DEBUG_TYPE "igc-wave-all-joint-reduction" + +using namespace IGC; +using namespace llvm; + +namespace IGC +{ + class WaveAllJointReductionImpl : public InstVisitor + { + public: + WaveAllJointReductionImpl( Function& F ) : F( F ) {} + bool run(); + void visitCallInst( CallInst& callInst ); + private: + Value* createInsertElements( SmallVector& mergeList ); + void createExtractElements( SmallVector& mergeList, 
WaveAllIntrinsic* waveAllJoint ); + Function& F; + DenseSet ToDelete; + bool Changed = false; + }; + + class WaveAllJointReduction: public FunctionPass + { + public: + static char ID; + WaveAllJointReduction() : FunctionPass( ID ) {} + + llvm::StringRef getPassName() const override + { + return "WaveAllJointReduction"; + } + bool runOnFunction( Function& F ) override; + }; + + FunctionPass* createWaveAllJointReduction() + { + return new WaveAllJointReduction(); + } +} + +Value* WaveAllJointReductionImpl::createInsertElements( SmallVector& mergeList ) +{ + IRBuilder<> builder( mergeList.front() ); + auto* vecType = VectorType::get( mergeList.front()->getSrc()->getType(), mergeList.size(), false ); + auto* vec = builder.CreateInsertElement( UndefValue::get( vecType ), mergeList.front()->getSrc(), (uint64_t)0, "waveAllSrc" ); + for( uint64_t i = 1; i < mergeList.size(); i++ ) + { + vec = builder.CreateInsertElement( vec, mergeList[ i ]->getSrc(), i, "waveAllSrc" ); + } + return vec; +} + +void WaveAllJointReductionImpl::createExtractElements( SmallVector& mergeList, WaveAllIntrinsic* waveAllJoint ) +{ + IRBuilder<> builder( mergeList.front() ); + for( uint64_t i = 0; i < mergeList.size(); i++ ) + { + auto* res = builder.CreateExtractElement( waveAllJoint, i, "waveAllDst" ); + mergeList[ i ]->replaceAllUsesWith( res ); + } +} + +void WaveAllJointReductionImpl::visitCallInst( CallInst& callInst ) +{ + + if( auto* waveAllInst = dyn_cast( &callInst ) ) + { + // marked as delete because it was already merged with prior insts + if( ToDelete.count( waveAllInst ) ) + { + return; + } + + // Optimization already happened, first operand is already vector + if( waveAllInst->getSrc()->getType()->isVectorTy() ) + { + return; + } + + SmallVector mergeList{ waveAllInst }; + + // For locality, only look at consecutive instructions since non-consecutive instructions may require sinking the final vector WaveAll instruction to where the last joined WaveAll is to satisfy proper domination 
of each WaveAll's Src + // TODO: If needed, a complicated analysis could find non-consecutive WaveAll instructions that are able to participate in WaveAll joint reduction, but seems like an edge case for now + Instruction* I = waveAllInst->getNextNode(); + while( I != waveAllInst->getParent()->getTerminator() ) + { + auto* nextWaveAllInst = dyn_cast( I ); + // TODO: Can check helper lane mode here if necessary, unsure whether that changes anything + if( !nextWaveAllInst || nextWaveAllInst->getSrc()->getType()->isVectorTy() || nextWaveAllInst->getSrc()->getType() != waveAllInst->getSrc()->getType() || nextWaveAllInst->getOpKind() != waveAllInst->getOpKind() ) + { + break; + } + + mergeList.push_back( nextWaveAllInst ); + + I = I->getNextNode(); + } + + if( mergeList.size() > 1 ) + { + // Multiple WaveAll operations eligible to participate in joint operation + auto* arg0 = createInsertElements( mergeList ); + IRBuilder<> builder( mergeList.front() ); + Type* funcType[] = { arg0->getType(), Type::getInt8Ty( builder.getContext() ), Type::getInt32Ty( builder.getContext() ) }; + Function* waveAllJointFunc = GenISAIntrinsic::getDeclaration( mergeList.front()->getModule(), GenISAIntrinsic::GenISA_WaveAll, funcType ); + + auto* waveAllJoint = builder.CreateCall( waveAllJointFunc, { arg0, waveAllInst->getOperand( 1 ), waveAllInst->getOperand( 2 ) }, "waveAllJoint" ); + createExtractElements( mergeList, cast( waveAllJoint ) ); + + // Mark merged WaveAll ops participating in joint operation for deletion + for( auto* mergedInst : mergeList ) + { + ToDelete.insert( mergedInst ); + } + Changed = true; + } + } +} + +bool WaveAllJointReductionImpl::run() +{ + visit( F ); + for( auto* mergedWaveAllInst : ToDelete ) + { + mergedWaveAllInst->eraseFromParent(); + } + return Changed; +} + +bool WaveAllJointReduction::runOnFunction( Function& F ) +{ + WaveAllJointReductionImpl WorkerInstance( F ); + return WorkerInstance.run(); +} + +char WaveAllJointReduction::ID = 0; + +#define 
PASS_FLAG "igc-wave-all-joint-reduction" +#define PASS_DESCRIPTION "WaveAllJointReduction" +#define PASS_CFG_ONLY false +#define PASS_ANALYSIS false +IGC_INITIALIZE_PASS_BEGIN( WaveAllJointReduction, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS ) +IGC_INITIALIZE_PASS_END( WaveAllJointReduction, PASS_FLAG, PASS_DESCRIPTION, PASS_CFG_ONLY, PASS_ANALYSIS ) diff --git a/IGC/Compiler/Optimizer/WaveAllJointReduction.hpp b/IGC/Compiler/Optimizer/WaveAllJointReduction.hpp new file mode 100644 index 000000000000..c0897d058752 --- /dev/null +++ b/IGC/Compiler/Optimizer/WaveAllJointReduction.hpp @@ -0,0 +1,18 @@ +/*========================== begin_copyright_notice ============================ + +Copyright (C) 2024 Intel Corporation + +SPDX-License-Identifier: MIT + +============================= end_copyright_notice ===========================*/ + +#pragma once + +#include "common/LLVMWarningsPush.hpp" +#include +#include "common/LLVMWarningsPop.hpp" + +namespace IGC +{ + llvm::FunctionPass* createWaveAllJointReduction(); +} // namespace IGC diff --git a/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-dual-simd16-group4.ll b/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-dual-simd16-group4.ll new file mode 100644 index 000000000000..5ffca1822f72 --- /dev/null +++ b/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-dual-simd16-group4.ll @@ -0,0 +1,132 @@ +;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2024 Intel Corporation +; +; SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= +; REQUIRES: regkeys +; +; RUN: igc_opt -platformdg2 -igc-emit-visa %s -inputcs -simd-mode 32 -regkey DumpVISAASMToConsole | FileCheck %s +; ------------------------------------------------ +; EmitVISAPass +; ------------------------------------------------ +target datalayout = 
"e-p:32:32:32-p1:64:64:64-p2:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:32-f32:32:32-f64:32:32-v64:32:32-v128:32:32-a0:0:32-n8:16:32-S32" +target triple = "dxil-ms-dx" + +@ThreadGroupSize_X = constant i32 1 +@ThreadGroupSize_Y = constant i32 1 +@ThreadGroupSize_Z = constant i32 32 + +; Function Attrs: null_pointer_is_valid +define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_value_2) #0 { + %src = inttoptr i32 %runtime_value_0 to <4 x float> addrspace(2490368)* + %dst = inttoptr i32 %runtime_value_2 to <4 x float> addrspace(2490369)* + %lane = call i16 @llvm.genx.GenISA.simdLaneId() + %lane32 = zext i16 %lane to i32 + %shl_runtime_value_1 = shl i32 %runtime_value_1, 2 + %shuffle_0 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %shl_runtime_value_1, i32 0, i32 0) + %shl_lane32 = shl i32 %lane32, 2 + %add_0 = add i32 %shuffle_0, %shl_lane32 + %a = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %add_0, i32 4, i1 false) + %shuffle_1 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %shl_runtime_value_1, i32 1, i32 0) + %add_1 = add i32 %shuffle_1, %shl_lane32 + %b = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %add_1, i32 4, i1 false) + %shuffle_2 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %shl_runtime_value_1, i32 2, i32 0) + %add_2 = add i32 %shuffle_2, %shl_lane32 + %c = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %add_2, i32 4, i1 false) + %shuffle_3 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %shl_runtime_value_1, i32 3, i32 0) + %add_3 = add i32 %shuffle_3, %shl_lane32 + %d = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %add_3, i32 4, i1 false) + %waveAllSrc0 = insertelement <4 x i32> undef, i32 %a, i64 0 + %waveAllSrc1 = insertelement <4 x i32> %waveAllSrc0, i32 %b, i64 1 + 
%waveAllSrc2 = insertelement <4 x i32> %waveAllSrc1, i32 %c, i64 2 + %waveAllSrc3 = insertelement <4 x i32> %waveAllSrc2, i32 %d, i64 3 +; move operands to consecutive GRF space (generated from insertelement instructions) +; CHECK: mov (M1, 16) waveAllSrc0(0,0)<1> a(0,0)<1;1,0> +; CHECK: mov (M1, 16) waveAllSrc0(2,0)<1> b(0,0)<1;1,0> +; CHECK: mov (M1, 16) waveAllSrc0(4,0)<1> c(0,0)<1;1,0> +; CHECK: mov (M1, 16) waveAllSrc0(6,0)<1> d(0,0)<1;1,0> +; move operands (secondHalf) to consecutive GRF space (one-time use space for first reduction layer) +; CHECK: mov (M5, 16) waveAllSrc0_0(0,0)<1> a_0(0,0)<1;1,0> +; CHECK: mov (M5, 16) waveAllSrc0_0(2,0)<1> b_0(0,0)<1;1,0> +; CHECK: mov (M5, 16) waveAllSrc0_0(4,0)<1> c_0(0,0)<1;1,0> +; CHECK: mov (M5, 16) waveAllSrc0_0(6,0)<1> d_0(0,0)<1;1,0> + +; Identity operations + layer 0 (simd-16 reduction of a single variable across 32 lanes) +; CHECK: mov (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 16) reduceSrc_waveAllSrc0(0,0)<1> waveAllSrc0(0,0)<1;1,0> +; CHECK-NEXT: mov (M5_NM, 16) reduceSrcSecondHalf_waveAllSrc0(0,0)<1> 0x0:d +; CHECK-NEXT: mov (M5, 16) reduceSrcSecondHalf_waveAllSrc0(0,0)<1> waveAllSrc0_0(0,0)<1;1,0> +; CHECK-NEXT: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0> reduceSrcSecondHalf_waveAllSrc0(0,0)<1;1,0> +; CHECK: mov (M1_NM, 16) reduceSrc_waveAllSrc0(2,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 16) reduceSrc_waveAllSrc0(2,0)<1> waveAllSrc0(2,0)<1;1,0> +; CHECK-NEXT: mov (M5_NM, 16) reduceSrcSecondHalf_waveAllSrc0(2,0)<1> 0x0:d +; CHECK-NEXT: mov (M5, 16) reduceSrcSecondHalf_waveAllSrc0(2,0)<1> waveAllSrc0_0(2,0)<1;1,0> +; CHECK-NEXT: add (M1_NM, 16) reduceSrc_waveAllSrc0(2,0)<1> reduceSrc_waveAllSrc0(2,0)<1;1,0> reduceSrcSecondHalf_waveAllSrc0(2,0)<1;1,0> +; CHECK: mov (M1_NM, 16) reduceSrc_waveAllSrc0(4,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 16) reduceSrc_waveAllSrc0(4,0)<1> waveAllSrc0(4,0)<1;1,0> +; CHECK-NEXT: mov (M5_NM, 16) 
reduceSrcSecondHalf_waveAllSrc0(4,0)<1> 0x0:d +; CHECK-NEXT: mov (M5, 16) reduceSrcSecondHalf_waveAllSrc0(4,0)<1> waveAllSrc0_0(4,0)<1;1,0> +; CHECK-NEXT: add (M1_NM, 16) reduceSrc_waveAllSrc0(4,0)<1> reduceSrc_waveAllSrc0(4,0)<1;1,0> reduceSrcSecondHalf_waveAllSrc0(4,0)<1;1,0> +; CHECK: mov (M1_NM, 16) reduceSrc_waveAllSrc0(6,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 16) reduceSrc_waveAllSrc0(6,0)<1> waveAllSrc0(6,0)<1;1,0> +; CHECK-NEXT: mov (M5_NM, 16) reduceSrcSecondHalf_waveAllSrc0(6,0)<1> 0x0:d +; CHECK-NEXT: mov (M5, 16) reduceSrcSecondHalf_waveAllSrc0(6,0)<1> waveAllSrc0_0(6,0)<1;1,0> +; CHECK-NEXT: add (M1_NM, 16) reduceSrc_waveAllSrc0(6,0)<1> reduceSrc_waveAllSrc0(6,0)<1;1,0> reduceSrcSecondHalf_waveAllSrc0(6,0)<1;1,0> + +; Joint Reduction Tree +; layer 1 +; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<16;8,1> reduceSrc_waveAllSrc0(1,0)<16;8,1> +; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(2,0)<1> reduceSrc_waveAllSrc0(4,0)<16;8,1> reduceSrc_waveAllSrc0(5,0)<16;8,1> +; layer 2 +; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<8;4,1> reduceSrc_waveAllSrc0(0,4)<8;4,1> +; layer 3 +; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1> +; layer 4 +; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1> +; copy to dest +; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0> + %waveAllJoint = call <4 x i32> @llvm.genx.GenISA.WaveAll.v4i32.i8.i32(<4 x i32> %waveAllSrc3, i8 0, i32 0) + %res_a = extractelement <4 x i32> %waveAllJoint, i32 0 + %res_b = extractelement <4 x i32> %waveAllJoint, i32 1 + %res_c = extractelement <4 x i32> %waveAllJoint, i32 2 + %res_d = extractelement <4 x i32> %waveAllJoint, i32 3 +; Proper replacement in subsequent instructions +; CHECK: add (M1_NM, 1) join_c_d(0,0)<1> waveAllJoint(0,2)<0;1,0> waveAllJoint(0,3)<0;1,0> 
+; CHECK: add3 (M1_NM, 1) join_a_b_c_d(0,0)<1> waveAllJoint(0,0)<0;1,0> waveAllJoint(0,1)<0;1,0> join_c_d(0,0)<0;1,0> + %join_a_b = add i32 %res_a, %res_b + %join_c_d = add i32 %res_c, %res_d + %join_a_b_c_d = add i32 %join_a_b, %join_c_d + %store = insertelement <1 x i32> undef, i32 %join_a_b_c_d, i64 0 + call void @llvm.genx.GenISA.storerawvector.indexed.p2490377v4f32.v1i32(<4 x float> addrspace(2490369)* %dst, i32 0, <1 x i32> %store, i32 4, i1 false) + ret void +} + +declare i16 @llvm.genx.GenISA.simdLaneId() #1 + +declare i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)*, i32, i32, i1) #2 + +declare <4 x i32> @llvm.genx.GenISA.WaveAll.v4i32.i8.i32(<4 x i32>, i8, i32) #3 + +declare i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32, i32, i32) #4 + +declare void @llvm.genx.GenISA.storerawvector.indexed.p2490377v4f32.v1i32(<4 x float> addrspace(2490369)*, i32, <1 x i32>, i32, i1) #5 + +attributes #0 = { null_pointer_is_valid } +attributes #1 = { nounwind readnone } +attributes #2 = { argmemonly nounwind readonly } +attributes #3 = { convergent inaccessiblememonly nounwind } +attributes #4 = { convergent nounwind readnone } +attributes #5 = { argmemonly nounwind writeonly } + +!igc.functions = !{!0} +!IGCMetadata = !{!3} + +!0 = !{void (i32, i32, i32)* @CSMain, !1} +!1 = !{!2} +!2 = !{!"function_type", i32 0} +!3 = !{!"ModuleMD", !4} +!4 = !{!"FuncMD", !5, !6} +!5 = !{!"FuncMDMap[0]", void (i32, i32, i32)* @CSMain} +!6 = !{!"FuncMDValue[0]"} \ No newline at end of file diff --git a/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-simd32-group17.ll b/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-simd32-group17.ll new file mode 100644 index 000000000000..7516109dc1e3 --- /dev/null +++ b/IGC/Compiler/tests/EmitVISAPass/wave-all-joint-reduction-simd32-group17.ll @@ -0,0 +1,237 @@ +;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2024 Intel Corporation +; +; 
SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= +; REQUIRES: regkeys +; +; RUN: igc_opt -platformbmg -igc-emit-visa %s -inputcs -simd-mode 32 -regkey DumpVISAASMToConsole | FileCheck %s +; ------------------------------------------------ +; EmitVISAPass: Compare group of 17 WaveAll reductions participating in a joint reduction tree to a single WaveAll reduction +; Joint reduction emits 75 instructions in total after EmitVISAPass +; - includes 17 potentially unnecessary mov instructions to get the inputs into the GRF aligned space +; - includes 2 potentially unnecessary mov instructions to move the reduction tree results to the destination +; Compared to 7 (non-joint WaveAll reduction) * 17 = 119 for 17 consecutive non-joint WaveAll instructions if they were not merged +; ------------------------------------------------ +target datalayout = "e-p:32:32:32-p1:64:64:64-p2:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:32-f32:32:32-f64:32:32-v64:32:32-v128:32:32-a0:0:32-n8:16:32-S32" +target triple = "dxil-ms-dx" + +@ThreadGroupSize_X = constant i32 1 +@ThreadGroupSize_Y = constant i32 1 +@ThreadGroupSize_Z = constant i32 32 + +; Function Attrs: null_pointer_is_valid +define void @CSMain(i32 %runtime_value_0, i32 %runtime_value_1, i32 %runtime_value_2) #0 { + %src = inttoptr i32 %runtime_value_0 to <4 x float> addrspace(2490368)* + %dst = inttoptr i32 %runtime_value_2 to <4 x float> addrspace(2490369)* + %lane = call i16 @llvm.genx.GenISA.simdLaneId() + %lane32 = zext i16 %lane to i32 + %shuffle_0 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 0, i32 0) + %add_0 = add i32 %shuffle_0, %lane32 + %shl_0 = shl i32 %add_0, 2 + %a = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_0, i32 4, i1 false) + %shuffle_1 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 1, i32 0) + %add_1 = add i32 
%shuffle_1, %lane32 + %shl_1 = shl i32 %add_1, 2 + %b = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_1, i32 4, i1 false) + %shuffle_2 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 2, i32 0) + %add_2 = add i32 %shuffle_2, %lane32 + %shl_2 = shl i32 %add_2, 2 + %c = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_2, i32 4, i1 false) + %shuffle_3 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 0, i32 0) + %add_3 = add i32 %shuffle_3, %lane32 + %shl_3 = shl i32 %add_3, 2 + %d = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_3, i32 4, i1 false) + %shuffle_4 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 0, i32 0) + %add_4 = add i32 %shuffle_4, %lane32 + %shl_4 = shl i32 %add_4, 2 + %e = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_4, i32 4, i1 false) + %shuffle_5 = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %runtime_value_1, i32 1, i32 0) + %add_5 = add i32 %shuffle_5, %lane32 + %shl_5 = shl i32 %add_5, 2 + %f = call i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)* %src, i32 %shl_5, i32 4, i1 false) + %waveAllSrc0 = insertelement <17 x i32> undef, i32 %add_0, i64 0 + %waveAllSrc1 = insertelement <17 x i32> %waveAllSrc0, i32 %shl_0, i64 1 + %waveAllSrc2 = insertelement <17 x i32> %waveAllSrc1, i32 %a, i64 2 + %waveAllSrc3 = insertelement <17 x i32> %waveAllSrc2, i32 %add_1, i64 3 + %waveAllSrc4 = insertelement <17 x i32> %waveAllSrc3, i32 %shl_1, i64 4 + %waveAllSrc5 = insertelement <17 x i32> %waveAllSrc4, i32 %b, i64 5 + %waveAllSrc6 = insertelement <17 x i32> %waveAllSrc5, i32 %add_2, i64 6 + %waveAllSrc7 = insertelement <17 x i32> %waveAllSrc6, i32 %shl_2, i64 7 + %waveAllSrc8 = insertelement <17 x 
i32> %waveAllSrc7, i32 %c, i64 8 + %waveAllSrc9 = insertelement <17 x i32> %waveAllSrc8, i32 %add_3, i64 9 + %waveAllSrc10 = insertelement <17 x i32> %waveAllSrc9, i32 %shl_3, i64 10 + %waveAllSrc11 = insertelement <17 x i32> %waveAllSrc10, i32 %d, i64 11 + %waveAllSrc12 = insertelement <17 x i32> %waveAllSrc11, i32 %add_4, i64 12 + %waveAllSrc13 = insertelement <17 x i32> %waveAllSrc12, i32 %shl_4, i64 13 + %waveAllSrc14 = insertelement <17 x i32> %waveAllSrc13, i32 %e, i64 14 + %waveAllSrc15 = insertelement <17 x i32> %waveAllSrc14, i32 %add_5, i64 15 + %waveAllSrc16 = insertelement <17 x i32> %waveAllSrc15, i32 %shl_5, i64 16 +; move operands to consecutive GRF space (generated from insertelement instructions, will likely be optimized away in the end) +; CHECK: mov (M1, 32) waveAllSrc0(0,0)<1> add_0(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(2,0)<1> shl_0(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(4,0)<1> a(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(6,0)<1> add_1(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(8,0)<1> shl_1(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(10,0)<1> b(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(12,0)<1> add_2(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(14,0)<1> shl_2(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(16,0)<1> c(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(18,0)<1> add_3(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(20,0)<1> shl_3(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(22,0)<1> d(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(24,0)<1> add_4(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(26,0)<1> shl_4(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(28,0)<1> e(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(30,0)<1> add_5(0,0)<1;1,0> +; CHECK: mov (M1, 32) waveAllSrc0(32,0)<1> shl_5(0,0)<1;1,0> + +; Identity operations +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(0,0)<1> waveAllSrc0(0,0)<1;1,0> +; CHECK: mov (M1_NM, 32) 
reduceSrc_waveAllSrc0(2,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(2,0)<1> waveAllSrc0(2,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(4,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(4,0)<1> waveAllSrc0(4,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(6,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(6,0)<1> waveAllSrc0(6,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(8,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(8,0)<1> waveAllSrc0(8,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(10,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(10,0)<1> waveAllSrc0(10,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(12,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(12,0)<1> waveAllSrc0(12,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(14,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(14,0)<1> waveAllSrc0(14,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(16,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(16,0)<1> waveAllSrc0(16,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(18,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(18,0)<1> waveAllSrc0(18,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(20,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(20,0)<1> waveAllSrc0(20,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(22,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(22,0)<1> waveAllSrc0(22,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(24,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(24,0)<1> waveAllSrc0(24,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(26,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(26,0)<1> waveAllSrc0(26,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(28,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(28,0)<1> 
waveAllSrc0(28,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(30,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(30,0)<1> waveAllSrc0(30,0)<1;1,0> +; CHECK: mov (M1_NM, 32) reduceSrc_waveAllSrc0(32,0)<1> 0x0:d +; CHECK-NEXT: mov (M1, 32) reduceSrc_waveAllSrc0(32,0)<1> waveAllSrc0(32,0)<1;1,0> +; Joint Reduction Tree (16-wide) +; layer 1 +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<32;16,1> reduceSrc_waveAllSrc0(1,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(2,0)<1> reduceSrc_waveAllSrc0(4,0)<32;16,1> reduceSrc_waveAllSrc0(5,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(4,0)<1> reduceSrc_waveAllSrc0(8,0)<32;16,1> reduceSrc_waveAllSrc0(9,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(6,0)<1> reduceSrc_waveAllSrc0(12,0)<32;16,1> reduceSrc_waveAllSrc0(13,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(8,0)<1> reduceSrc_waveAllSrc0(16,0)<32;16,1> reduceSrc_waveAllSrc0(17,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(10,0)<1> reduceSrc_waveAllSrc0(20,0)<32;16,1> reduceSrc_waveAllSrc0(21,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(12,0)<1> reduceSrc_waveAllSrc0(24,0)<32;16,1> reduceSrc_waveAllSrc0(25,0)<32;16,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(14,0)<1> reduceSrc_waveAllSrc0(28,0)<32;16,1> reduceSrc_waveAllSrc0(29,0)<32;16,1> +; layer 2 +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<16;8,1> reduceSrc_waveAllSrc0(0,8)<16;8,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(2,0)<1> reduceSrc_waveAllSrc0(4,0)<16;8,1> reduceSrc_waveAllSrc0(4,8)<16;8,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(4,0)<1> reduceSrc_waveAllSrc0(8,0)<16;8,1> reduceSrc_waveAllSrc0(8,8)<16;8,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(6,0)<1> reduceSrc_waveAllSrc0(12,0)<16;8,1> reduceSrc_waveAllSrc0(12,8)<16;8,1> +; layer 3 +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> 
reduceSrc_waveAllSrc0(0,0)<8;4,1> reduceSrc_waveAllSrc0(0,4)<8;4,1> +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(2,0)<1> reduceSrc_waveAllSrc0(4,0)<8;4,1> reduceSrc_waveAllSrc0(4,4)<8;4,1> +; layer 4 +; CHECK: add (M1_NM, 32) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<4;2,1> reduceSrc_waveAllSrc0(0,2)<4;2,1> +; layer 5 +; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(0,0)<1> reduceSrc_waveAllSrc0(0,0)<2;1,1> reduceSrc_waveAllSrc0(0,1)<2;1,1> +; copy to dest +; CHECK: mov (M1_NM, 1) waveAllJoint(0,0)<1> reduceSrc_waveAllSrc0(0,0)<1;1,0> +; Joint Reduction Tree (1-wide, leftover from splitting the 17-wide vector into 16 and 1, almost identical to existing non-joint reduction tree generated from scalar WaveAll intrinsic further below) +; CHECK: add (M1_NM, 16) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<32;16,1> reduceSrc_waveAllSrc0(33,0)<32;16,1> +; CHECK: add (M1_NM, 8) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<16;8,1> reduceSrc_waveAllSrc0(32,8)<16;8,1> +; CHECK: add (M1_NM, 4) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<8;4,1> reduceSrc_waveAllSrc0(32,4)<8;4,1> +; CHECK: add (M1_NM, 2) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<4;2,1> reduceSrc_waveAllSrc0(32,2)<4;2,1> +; CHECK: add (M1_NM, 1) reduceSrc_waveAllSrc0(32,0)<1> reduceSrc_waveAllSrc0(32,0)<2;1,1> reduceSrc_waveAllSrc0(32,1)<2;1,1> +; CHECK: mov (M1_NM, 1) waveAllJoint(1,0)<1> reduceSrc_waveAllSrc0(32,0)<1;1,0> + %waveAllJoint = call <17 x i32> @llvm.genx.GenISA.WaveAll.v17i32.i8.i32(<17 x i32> %waveAllSrc16, i8 0, i32 0) + %res_f = call i32 @llvm.genx.GenISA.WaveAll.i32.i8.i32(i32 %f, i8 0, i32 0) + %res_add_0 = extractelement <17 x i32> %waveAllJoint, i32 0 + %res_shl_0 = extractelement <17 x i32> %waveAllJoint, i32 1 + %res_a = extractelement <17 x i32> %waveAllJoint, i32 2 + %res_add_1 = extractelement <17 x i32> %waveAllJoint, i32 3 + %res_shl_1 = extractelement <17 x i32> %waveAllJoint, i32 4 + %res_b = extractelement 
<17 x i32> %waveAllJoint, i32 5 + %res_add_2 = extractelement <17 x i32> %waveAllJoint, i32 6 + %res_shl_2 = extractelement <17 x i32> %waveAllJoint, i32 7 + %res_c = extractelement <17 x i32> %waveAllJoint, i32 8 + %res_add_3 = extractelement <17 x i32> %waveAllJoint, i32 9 + %res_shl_3 = extractelement <17 x i32> %waveAllJoint, i32 10 + %res_d = extractelement <17 x i32> %waveAllJoint, i32 11 + %res_add_4 = extractelement <17 x i32> %waveAllJoint, i32 12 + %res_shl_4 = extractelement <17 x i32> %waveAllJoint, i32 13 + %res_e = extractelement <17 x i32> %waveAllJoint, i32 14 + %res_add_5 = extractelement <17 x i32> %waveAllJoint, i32 15 + %res_shl_5 = extractelement <17 x i32> %waveAllJoint, i32 16 +; Proper replacement in subsequent instructions +; CHECK: add (M1_NM, 1) join_a_0(0,0)<1> waveAllJoint(0,0)<0;1,0> waveAllJoint(0,1)<0;1,0> + %join_a_0 = add i32 %res_add_0, %res_shl_0 + %join_a_1 = add i32 %join_a_0, %res_a +; CHECK: add3 (M1_NM, 1) join_b_1(0,0)<1> waveAllJoint(0,3)<0;1,0> waveAllJoint(0,4)<0;1,0> waveAllJoint(0,5)<0;1,0> + %join_b_0 = add i32 %res_add_1, %res_shl_1 + %join_b_1 = add i32 %join_b_0, %res_b +; CHECK: add (M1_NM, 1) join_c_0(0,0)<1> waveAllJoint(0,6)<0;1,0> waveAllJoint(0,7)<0;1,0> + %join_c_0 = add i32 %res_add_2, %res_shl_2 + %join_c_1 = add i32 %join_c_0, %res_c +; CHECK: add3 (M1_NM, 1) join_d_1(0,0)<1> waveAllJoint(0,9)<0;1,0> waveAllJoint(0,10)<0;1,0> waveAllJoint(0,11)<0;1,0> + %join_d_0 = add i32 %res_add_3, %res_shl_3 + %join_d_1 = add i32 %join_d_0, %res_d +; CHECK: add (M1_NM, 1) join_e_0(0,0)<1> waveAllJoint(0,12)<0;1,0> waveAllJoint(0,13)<0;1,0> + %join_e_0 = add i32 %res_add_4, %res_shl_4 + %join_e_1 = add i32 %join_e_0, %res_e +; CHECK: add3 (M1_NM, 1) join_f_1(0,0)<1> waveAllJoint(0,15)<0;1,0> waveAllJoint(1,0)<0;1,0> res_f(0,0)<0;1,0> + %join_f_0 = add i32 %res_add_5, %res_shl_5 + %join_f_1 = add i32 %join_f_0, %res_f +; CHECK: add3 (M1_NM, 1) join_ab(0,0)<1> join_a_0(0,0)<0;1,0> waveAllJoint(0,2)<0;1,0> 
join_b_1(0,0)<0;1,0> + %join_ab = add i32 %join_a_1, %join_b_1 +; CHECK: add3 (M1_NM, 1) join_cd(0,0)<1> join_c_0(0,0)<0;1,0> waveAllJoint(0,8)<0;1,0> join_d_1(0,0)<0;1,0> + %join_cd = add i32 %join_c_1, %join_d_1 +; CHECK: add3 (M1_NM, 1) join_ef(0,0)<1> join_e_0(0,0)<0;1,0> waveAllJoint(0,14)<0;1,0> join_f_1(0,0)<0;1,0> + %join_ef = add i32 %join_e_1, %join_f_1 + %join_ab_cd = add i32 %join_ab, %join_cd +; CHECK: add3 (M1_NM, 1) join_ab_cd_ef(0,0)<1> join_ab(0,0)<0;1,0> join_cd(0,0)<0;1,0> join_ef(0,0)<0;1,0> + %join_ab_cd_ef = add i32 %join_ab_cd, %join_ef + %store = insertelement <1 x i32> undef, i32 %join_ab_cd_ef, i64 0 + call void @llvm.genx.GenISA.storerawvector.indexed.p2490377v4f32.v1i32(<4 x float> addrspace(2490369)* %dst, i32 0, <1 x i32> %store, i32 4, i1 false) + ret void +} + +declare i16 @llvm.genx.GenISA.simdLaneId() #1 + +declare i32 @llvm.genx.GenISA.ldraw.indexed.i32.p2490368v4f32(<4 x float> addrspace(2490368)*, i32, i32, i1) #2 + +declare <17 x i32> @llvm.genx.GenISA.WaveAll.v17i32.i8.i32(<17 x i32>, i8, i32) #3 +declare <1 x i32> @llvm.genx.GenISA.WaveAll.v1i32.i8.i32(<1 x i32>, i8, i32) #3 +declare i32 @llvm.genx.GenISA.WaveAll.i32.i8.i32(i32, i8, i32) #3 + +declare i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32, i32, i32) #4 + +declare void @llvm.genx.GenISA.storerawvector.indexed.p2490377v4f32.v1i32(<4 x float> addrspace(2490369)*, i32, <1 x i32>, i32, i1) #5 + +attributes #0 = { null_pointer_is_valid } +attributes #1 = { nounwind readnone } +attributes #2 = { argmemonly nounwind readonly } +attributes #3 = { convergent inaccessiblememonly nounwind } +attributes #4 = { convergent nounwind readnone } +attributes #5 = { argmemonly nounwind writeonly } + +!igc.functions = !{!0} +!IGCMetadata = !{!3} + +!0 = !{void (i32, i32, i32)* @CSMain, !1} +!1 = !{!2} +!2 = !{!"function_type", i32 0} +!3 = !{!"ModuleMD", !4} +!4 = !{!"FuncMD", !5, !6} +!5 = !{!"FuncMDMap[0]", void (i32, i32, i32)* @CSMain} +!6 = !{!"FuncMDValue[0]"} \ No newline at end 
of file diff --git a/IGC/Compiler/tests/WaveAllJointReduction/basic.ll b/IGC/Compiler/tests/WaveAllJointReduction/basic.ll new file mode 100644 index 000000000000..e7edc9192c53 --- /dev/null +++ b/IGC/Compiler/tests/WaveAllJointReduction/basic.ll @@ -0,0 +1,58 @@ +;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2024 Intel Corporation +; +; SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= +; RUN: igc_opt -igc-wave-all-joint-reduction -S < %s | FileCheck %s +; ------------------------------------------------ +; WaveAllJointReduction: merge consecutive independent WaveAll operations into a single WaveAll joint operation +; ------------------------------------------------ + +define void @test_wave_all_joint_reduction(i32* %dst, i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g, i32 %h) { +; CHECK: [[IN_A:%.*]] = insertelement <8 x i32> undef, i32 %a, i64 0 +; CHECK-NEXT: [[IN_AB:%.*]] = insertelement <8 x i32> [[IN_A]], i32 %b, i64 1 +; CHECK-NEXT: [[IN_ABC:%.*]] = insertelement <8 x i32> [[IN_AB]], i32 %c, i64 2 +; CHECK-NEXT: [[IN_ABCD:%.*]] = insertelement <8 x i32> [[IN_ABC]], i32 %d, i64 3 +; CHECK-NEXT: [[IN_ABCDE:%.*]] = insertelement <8 x i32> [[IN_ABCD]], i32 %e, i64 4 +; CHECK-NEXT: [[IN_ABCDEF:%.*]] = insertelement <8 x i32> [[IN_ABCDE]], i32 %f, i64 5 +; CHECK-NEXT: [[IN_ABCDEFG:%.*]] = insertelement <8 x i32> [[IN_ABCDEF]], i32 %g, i64 6 +; CHECK-NEXT: [[IN_ABCDEFGH:%.*]] = insertelement <8 x i32> [[IN_ABCDEFG]], i32 %h, i64 7 +; CHECK-NEXT: [[WAVE_ALL:%.*]] = call <8 x i32> @llvm.genx.GenISA.WaveAll.v8i32.i8.i32(<8 x i32> [[IN_ABCDEFGH]], i8 0, i32 0) +; CHECK-NOT: call i32 @llvm.genx.GenISA.WaveAll.i32 + %res_a = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %a, i8 0, i32 0) + %res_b = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %b, i8 0, i32 0) + %res_c = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %c, i8 0, i32 0) + %res_d = call i32 
@llvm.genx.GenISA.WaveAll.i32(i32 %d, i8 0, i32 0) + %res_e = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %e, i8 0, i32 0) + %res_f = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %f, i8 0, i32 0) + %res_g = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %g, i8 0, i32 0) + %res_h = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %h, i8 0, i32 0) +; CHECK: [[RES_A:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 0 +; CHECK-NEXT: [[RES_B:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 1 +; CHECK-NEXT: [[RES_C:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 2 +; CHECK-NEXT: [[RES_D:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 3 +; CHECK-NEXT: [[RES_E:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 4 +; CHECK-NEXT: [[RES_F:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 5 +; CHECK-NEXT: [[RES_G:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 6 +; CHECK-NEXT: [[RES_H:%.*]] = extractelement <8 x i32> [[WAVE_ALL]], i64 7 +; CHECK: %join_a_b = add i32 [[RES_A]], [[RES_B]] + %join_a_b = add i32 %res_a, %res_b +; CHECK: %join_c_d = add i32 [[RES_C]], [[RES_D]] + %join_c_d = add i32 %res_c, %res_d +; CHECK: %join_e_f = add i32 [[RES_E]], [[RES_F]] + %join_e_f = add i32 %res_e, %res_f +; CHECK: %join_g_h = add i32 [[RES_G]], [[RES_H]] + %join_g_h = add i32 %res_g, %res_h + %join_a_b_c_d = add i32 %join_a_b, %join_c_d + %join_e_f_g_h = add i32 %join_e_f, %join_g_h + %join_a_b_c_d_e_f_g_h = add i32 %join_a_b_c_d, %join_e_f_g_h + store i32 %join_a_b_c_d_e_f_g_h, i32* %dst + ret void +} + +; Function Attrs: convergent inaccessiblememonly nounwind +declare i32 @llvm.genx.GenISA.WaveAll.i32(i32, i8, i32) #0 + +attributes #0 = { convergent inaccessiblememonly nounwind } diff --git a/IGC/Compiler/tests/WaveAllJointReduction/separated-groups.ll b/IGC/Compiler/tests/WaveAllJointReduction/separated-groups.ll new file mode 100644 index 000000000000..d2f598296711 --- /dev/null +++ b/IGC/Compiler/tests/WaveAllJointReduction/separated-groups.ll @@ -0,0 +1,61 @@ 
+;=========================== begin_copyright_notice ============================ +; +; Copyright (C) 2024 Intel Corporation +; +; SPDX-License-Identifier: MIT +; +;============================ end_copyright_notice ============================= +; RUN: igc_opt -igc-wave-all-joint-reduction -S < %s | FileCheck %s +; ------------------------------------------------ +; WaveAllJointReduction: merge and group independent WaveAll operations into two WaveAll joint operations +; ------------------------------------------------ + +define void @test_wave_all_joint_reduction(i32* %dst, i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g, i32 %h) { +; CHECK: [[IN_A:%.*]] = insertelement <3 x i32> undef, i32 %a, i64 0 +; CHECK-NEXT: [[IN_AB:%.*]] = insertelement <3 x i32> [[IN_A]], i32 %b, i64 1 +; CHECK-NEXT: [[IN_ABC:%.*]] = insertelement <3 x i32> [[IN_AB]], i32 %c, i64 2 +; CHECK-NEXT: [[WAVE_ALL_ABC:%.*]] = call <3 x i32> @llvm.genx.GenISA.WaveAll.v3i32.i8.i32(<3 x i32> [[IN_ABC]], i8 0, i32 0) +; CHECK-NOT: call i32 @llvm.genx.GenISA.WaveAll.i32 + %res_a = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %a, i8 0, i32 0) + %res_b = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %b, i8 0, i32 0) + %res_c = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %c, i8 0, i32 0) +; CHECK: [[RES_A:%.*]] = extractelement <3 x i32> [[WAVE_ALL_ABC]], i64 0 +; CHECK-NEXT: [[RES_B:%.*]] = extractelement <3 x i32> [[WAVE_ALL_ABC]], i64 1 +; CHECK-NEXT: [[RES_C:%.*]] = extractelement <3 x i32> [[WAVE_ALL_ABC]], i64 2 + %separator = add i32 %a, %b +; CHECK: [[IN_D:%.*]] = insertelement <5 x i32> undef, i32 %d, i64 0 +; CHECK-NEXT: [[IN_DE:%.*]] = insertelement <5 x i32> [[IN_D]], i32 %e, i64 1 +; CHECK-NEXT: [[IN_DEF:%.*]] = insertelement <5 x i32> [[IN_DE]], i32 %f, i64 2 +; CHECK-NEXT: [[IN_DEFG:%.*]] = insertelement <5 x i32> [[IN_DEF]], i32 %g, i64 3 +; CHECK-NEXT: [[IN_DEFGH:%.*]] = insertelement <5 x i32> [[IN_DEFG]], i32 %h, i64 4 +; CHECK-NEXT: [[WAVE_ALL_DEFGH:%.*]] = call <5 x i32> 
@llvm.genx.GenISA.WaveAll.v5i32.i8.i32(<5 x i32> [[IN_DEFGH]], i8 0, i32 0) +; CHECK-NOT: call i32 @llvm.genx.GenISA.WaveAll.i32 + %res_d = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %d, i8 0, i32 0) + %res_e = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %e, i8 0, i32 0) + %res_f = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %f, i8 0, i32 0) + %res_g = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %g, i8 0, i32 0) + %res_h = call i32 @llvm.genx.GenISA.WaveAll.i32(i32 %h, i8 0, i32 0) +; CHECK: [[RES_D:%.*]] = extractelement <5 x i32> [[WAVE_ALL_DEFGH]], i64 0 +; CHECK-NEXT: [[RES_E:%.*]] = extractelement <5 x i32> [[WAVE_ALL_DEFGH]], i64 1 +; CHECK-NEXT: [[RES_F:%.*]] = extractelement <5 x i32> [[WAVE_ALL_DEFGH]], i64 2 +; CHECK-NEXT: [[RES_G:%.*]] = extractelement <5 x i32> [[WAVE_ALL_DEFGH]], i64 3 +; CHECK-NEXT: [[RES_H:%.*]] = extractelement <5 x i32> [[WAVE_ALL_DEFGH]], i64 4 +; CHECK: %join_a_b = add i32 [[RES_A]], [[RES_B]] + %join_a_b = add i32 %res_a, %res_b +; CHECK: %join_c_d = add i32 [[RES_C]], [[RES_D]] + %join_c_d = add i32 %res_c, %res_d +; CHECK: %join_e_f = add i32 [[RES_E]], [[RES_F]] + %join_e_f = add i32 %res_e, %res_f +; CHECK: %join_g_h = add i32 [[RES_G]], [[RES_H]] + %join_g_h = add i32 %res_g, %res_h + %join_a_b_c_d = add i32 %join_a_b, %join_c_d + %join_e_f_g_h = add i32 %join_e_f, %join_g_h + %join_a_b_c_d_e_f_g_h = add i32 %join_a_b_c_d, %join_e_f_g_h + store i32 %join_a_b_c_d_e_f_g_h, i32* %dst + ret void +} + +; Function Attrs: convergent inaccessiblememonly nounwind +declare i32 @llvm.genx.GenISA.WaveAll.i32(i32, i8, i32) #0 + +attributes #0 = { convergent inaccessiblememonly nounwind } diff --git a/IGC/GenISAIntrinsics/GenIntrinsicInst.h b/IGC/GenISAIntrinsics/GenIntrinsicInst.h index e74fff997a4a..f86b5f865aef 100644 --- a/IGC/GenISAIntrinsics/GenIntrinsicInst.h +++ b/IGC/GenISAIntrinsics/GenIntrinsicInst.h @@ -1246,6 +1246,24 @@ class QuadPrefixIntrinsic : public GenIntrinsicInst } }; +class WaveAllIntrinsic : public GenIntrinsicInst 
+{ +public: + Value *getSrc() const { return getOperand(0); } + IGC::WaveOps getOpKind() const + { + return static_cast<IGC::WaveOps>(getImm64Operand(1)); + } + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const GenIntrinsicInst *I) { + return I->getIntrinsicID() == GenISAIntrinsic::GenISA_WaveAll; + } + static inline bool classof(const Value *V) { + return isa<GenIntrinsicInst>(V) && classof(cast<GenIntrinsicInst>(V)); + } +}; + +// This is just a meta intrinsic that encapsulates the idea of intrinsics +// that contain continuation IDs. class ContinuationHLIntrinsic : public GenIntrinsicInst { diff --git a/IGC/common/igc_flags.h index bf508fde109b..72bddcc985fc 100644 --- a/IGC/common/igc_flags.h +++ b/IGC/common/igc_flags.h @@ -323,6 +323,7 @@ DECLARE_IGC_REGKEY(bool, DisableLoopSplitWidePHIs, false, "Disable splitting of DECLARE_IGC_REGKEY(bool, EnableBarrierControlFlowOptimizationPass, false, "Enable barrier control flow optimization pass", false) DECLARE_IGC_REGKEY(bool, EnableWaveShuffleIndexSinking, false, "Hoist identical instructions operating on WaveShuffleIndex instructions with the same source and a constant lane/channel", false) DECLARE_IGC_REGKEY(DWORD, WaveShuffleIndexSinkingMaxIterations, 3, "Max number of iterations to run iterative WaveShuffleIndexSinking", false) +DECLARE_IGC_REGKEY(bool, EnableWaveAllJointReduction, false, "Enable Joint Reduction Optimization.", false) DECLARE_IGC_GROUP("Shader debugging") DECLARE_IGC_REGKEY(bool, CopyA0ToDBG0, false, " Copy a0 used for extended msg descriptor to dbg0 to help debug", false)