Skip to content

Commit

Permalink
sub group clustered ballot
Browse files Browse the repository at this point in the history
Adds new intrinsic: sub group clustered ballot. Works similar to sub
group ballot, but each lane contains results only from its' cluster.
Only cluster sizes 8 and 16 are supported.
  • Loading branch information
pkwasnie-intel authored and igcbot committed Dec 3, 2024
1 parent ba264a3 commit 57461d9
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 2 deletions.
1 change: 1 addition & 0 deletions IGC/BiFModule/Implementation/IGCBiF_Intrinsics.cl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ uint __builtin_IB_get_image_bti(uint img);

// ballot intrinsic
uint __builtin_IB_WaveBallot(bool p);
uint __builtin_IB_clustered_WaveBallot(bool p, uint cluster_size);

// VA
void __builtin_IB_va_erode_64x4( __local uchar* dst, float2 coords, int srcImgId, int i_accelerator );
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/CheckInstrTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ void CheckInstrTypes::visitCallInst(CallInst& C)
case GenISAIntrinsic::GenISA_WaveBallot:
case GenISAIntrinsic::GenISA_wavebarrier:
case GenISAIntrinsic::GenISA_WaveInverseBallot:
case GenISAIntrinsic::GenISA_WaveClusteredBallot:
case GenISAIntrinsic::GenISA_WavePrefix:
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/CodeSinking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2492,6 +2492,7 @@ namespace IGC {
case GenISAIntrinsic::GenISA_WaveClusteredBroadcast:
case GenISAIntrinsic::GenISA_WaveBallot:
case GenISAIntrinsic::GenISA_WaveInverseBallot:
case GenISAIntrinsic::GenISA_WaveClusteredBallot:
case GenISAIntrinsic::GenISA_WaveAll:
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
Expand Down
88 changes: 86 additions & 2 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9173,6 +9173,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
case GenISAIntrinsic::GenISA_WaveInverseBallot:
emitWaveInverseBallot(inst);
break;
case GenISAIntrinsic::GenISA_WaveClusteredBallot:
emitWaveClusteredBallot(inst);
break;
case GenISAIntrinsic::GenISA_WaveShuffleIndex:
case GenISAIntrinsic::GenISA_WaveBroadcast:
emitSimdShuffle(inst);
Expand Down Expand Up @@ -21551,6 +21554,23 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
destination = m_currShader->GetNewVariable(1, ISA_TYPE_UD, EALIGN_GRF, true, CName::NONE);
}

emitBallotUniform(inst, destination, disableHelperLanes);

if (destination != m_destination)
{
m_encoder->Cast(m_destination, destination);
m_encoder->Push();
}
if (disableHelperLanes)
{
ResetVMask();
}
}

void EmitPass::emitBallotUniform(llvm::GenIntrinsicInst* inst, CVariable* destination, bool disableHelperLanes)
{
IGC_ASSERT_MESSAGE(destination->IsUniform(), "Unsupported: dst must be uniform");

bool uniform_active_lane = false;
if (ConstantInt * pConst = dyn_cast<ConstantInt>(inst->getOperand(0)))
{
Expand Down Expand Up @@ -21604,12 +21624,76 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
m_encoder->Push();
}
}
}

if (destination != m_destination)
void EmitPass::emitWaveClusteredBallot(llvm::GenIntrinsicInst* inst)
{
IGC_ASSERT_MESSAGE(!m_destination->IsUniform(), "Unsupported: dst must be non-uniform");

IGC_ASSERT_MESSAGE(isa<llvm::ConstantInt>(inst->getOperand(1)), "Unsupported: cluster size must be constant");
const unsigned int clusterSize = int_cast<uint32_t>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());

IGC_ASSERT_MESSAGE(clusterSize < numLanes(m_currShader->m_dispatchSize), "cluster size must be smaller than SIMD");
IGC_ASSERT_MESSAGE(clusterSize == 8 || clusterSize == 16, "cluster size must be 8 or 16");

bool disableHelperLanes = int_cast<int>(cast<ConstantInt>(inst->getArgOperand(2))->getSExtValue()) == 2;
if (disableHelperLanes)
{
m_encoder->Cast(m_destination, destination);
ForceDMask();
}

// Run ballot.
CVariable* ballotResult = m_currShader->GetNewVariable(1, ISA_TYPE_UD, EALIGN_GRF, true, "ballotResult");
emitBallotUniform(inst, ballotResult, disableHelperLanes);

// ballotResult contains result from all lanes. Cluster can be either 8 or 16 lanes, so clusters in
// ballotResult are byte-aligned. Extract clusters from the result.

CVariable* zero = m_currShader->ImmToVariable(0, ISA_TYPE_UD);
m_encoder->Copy(m_destination, zero);
if (m_currShader->m_numberInstance > 1)
{
m_encoder->SetSecondHalf(true);
m_encoder->Copy(m_destination, zero);
m_encoder->SetSecondHalf(false);
}
m_encoder->Push();

if (clusterSize == 8)
{
CVariable* ballotAlias = m_currShader->GetNewAlias(ballotResult, ISA_TYPE_B, 0, 4, false);
CVariable* dstAlias = m_currShader->GetNewAlias(m_destination, ISA_TYPE_B, 0, numLanes(m_currShader->m_SIMDSize) * 4);

m_encoder->SetSrcRegion(0, 1, 8, 0);
m_encoder->SetDstRegion(4);
m_encoder->Copy(dstAlias, ballotAlias);
if (m_currShader->m_numberInstance > 1)
{
m_encoder->SetSecondHalf(true);
m_encoder->SetSrcSubReg(0, 2);
m_encoder->Copy(dstAlias, ballotAlias);
m_encoder->SetSecondHalf(false);
}
m_encoder->Push();
}
else if (clusterSize == 16)
{
CVariable* ballotAlias = m_currShader->GetNewAlias(ballotResult, ISA_TYPE_UW, 0, 2, false);
CVariable* dstAlias = m_currShader->GetNewAlias(m_destination, ISA_TYPE_UW, 0, numLanes(m_currShader->m_SIMDSize) * 2);

m_encoder->SetSrcRegion(0, 1, 16, 0);
m_encoder->SetDstRegion(2);
m_encoder->Copy(dstAlias, ballotAlias);
if (m_currShader->m_numberInstance > 1)
{
m_encoder->SetSecondHalf(true);
m_encoder->SetSrcSubReg(0, 1);
m_encoder->Copy(dstAlias, ballotAlias);
m_encoder->SetSecondHalf(false);
}
m_encoder->Push();
}

if (disableHelperLanes)
{
ResetVMask();
Expand Down
2 changes: 2 additions & 0 deletions IGC/Compiler/CISACodeGen/EmitVISAPass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ class EmitPass : public llvm::FunctionPass

// CrossLane Instructions
void emitWaveBallot(llvm::GenIntrinsicInst* inst);
void emitWaveClusteredBallot(llvm::GenIntrinsicInst* inst);
void emitBallotUniform(llvm::GenIntrinsicInst* inst, CVariable* destination, bool disableHelperLanes);
void emitWaveInverseBallot(llvm::GenIntrinsicInst* inst);
void emitWaveShuffleIndex(llvm::GenIntrinsicInst* inst);
void emitWavePrefix(llvm::WavePrefixIntrinsic* I);
Expand Down
2 changes: 2 additions & 0 deletions IGC/Compiler/CISACodeGen/PatternMatchPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,7 @@ namespace IGC
break;
case GenISAIntrinsic::GenISA_WaveBallot:
case GenISAIntrinsic::GenISA_WaveInverseBallot:
case GenISAIntrinsic::GenISA_WaveClusteredBallot:
case GenISAIntrinsic::GenISA_WaveAll:
case GenISAIntrinsic::GenISA_WaveClustered:
case GenISAIntrinsic::GenISA_WaveInterleave:
Expand Down Expand Up @@ -5226,6 +5227,7 @@ namespace IGC
switch (I.getIntrinsicID())
{
case GenISAIntrinsic::GenISA_WaveAll:
case GenISAIntrinsic::GenISA_WaveClusteredBallot:
helperLaneIndex = 2;
break;
case GenISAIntrinsic::GenISA_WaveBallot:
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/WIAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,7 @@ WIAnalysis::WIDependancy WIAnalysisRunner::calculate_dep(const CallInst* inst)
intrinsic_name == llvm_waveBroadcast ||
intrinsic_name == llvm_waveClusteredBroadcast ||
intrinsic_name == llvm_waveBallot ||
intrinsic_name == llvm_waveClusteredBallot ||
intrinsic_name == llvm_waveAll ||
intrinsic_name == llvm_waveClustered ||
intrinsic_name == llvm_waveInterleave ||
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,7 @@ namespace IGC
opcode == llvm_waveBroadcast ||
opcode == llvm_waveClusteredBroadcast ||
opcode == llvm_waveBallot ||
opcode == llvm_waveClusteredBallot ||
opcode == llvm_simdShuffleDown ||
opcode == llvm_simdBlockRead||
opcode == llvm_simdBlockReadBindless);
Expand Down
1 change: 1 addition & 0 deletions IGC/Compiler/CISACodeGen/opCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ DECLARE_OPCODE(GenISA_pair_to_ptr, GenISAIntrinsic, llvm_pair_to_ptr, false, fal

// Wave intrinsics
DECLARE_OPCODE(GenISA_WaveBallot, GenISAIntrinsic, llvm_waveBallot, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveClusteredBallot, GenISAIntrinsic, llvm_waveClusteredBallot, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveAll, GenISAIntrinsic, llvm_waveAll, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveClustered, GenISAIntrinsic, llvm_waveClustered, false, false, false, false, false, false, false)
DECLARE_OPCODE(GenISA_WaveInterleave, GenISAIntrinsic, llvm_waveInterleave, false, false, false, false, false, false, false)
Expand Down
7 changes: 7 additions & 0 deletions IGC/Compiler/Optimizer/OCLBIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,6 +1206,12 @@ class CWaveBallotIntrinsic : public CCommand
}

m_args.push_back(truncInst);

if (isaId == GenISAIntrinsic::GenISA_WaveClusteredBallot)
{
m_args.push_back(m_pCallInst->getArgOperand(1));
}

m_args.push_back(IRB.getInt32(0));
replaceGenISACallInst(isaId);
}
Expand Down Expand Up @@ -1761,6 +1767,7 @@ CBuiltinsResolver::CBuiltinsResolver(CImagesBI::ParamMap* paramMap, CImagesBI::I

// Ballot builtins
m_CommandMap["__builtin_IB_WaveBallot"] = CWaveBallotIntrinsic::create(GenISAIntrinsic::GenISA_WaveBallot);
m_CommandMap["__builtin_IB_clustered_WaveBallot"] = CWaveBallotIntrinsic::create(GenISAIntrinsic::GenISA_WaveClusteredBallot);

m_CommandMap[StringRef("__builtin_IB_samplepos")] = CSamplePos::create();

Expand Down
28 changes: 28 additions & 0 deletions IGC/GenISAIntrinsics/generator/input/Intrinsic_definitions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2171,6 +2171,34 @@ intrinsics:
memory_effects:
- !<MemoryRestriction>
memory_location: !MemoryLocation InaccessibleMem
- !<IntrinsicDefinition>
name: "GenISA_WaveClusteredBallot"
comment: "Same as WaveBallot, but result includes only lanes for current cluster.\
\ Works in non-uniform context."
return_definition: !<ReturnDefinition>
type_definition: *i32
comment: "return a bitfield with 1 for active lane in cluster with input true,\
\ 0 for the rest."
arguments:
- !<ArgumentDefinition>
name: Arg0
type_definition: *i1
comment: "predicate"
- !<ArgumentDefinition>
name: Arg1
type_definition: *i32
comment: "cluster size - must be a compile time constant 8 or 16"
- !<ArgumentDefinition>
name: Arg2
type_definition: *i32
comment: "helperLaneMode : 0: not used; 1: helper lanes participatein\
\ wave ops, 2: helper lanes do not participate in wave ops."
attributes:
- !AttributeID "Convergent"
- !AttributeID "NoUnwind"
memory_effects:
- !<MemoryRestriction>
memory_location: !MemoryLocation InaccessibleMem
- !<IntrinsicDefinition>
name: "GenISA_WaveClustered"
comment: "Accumulate all active lanes within consecutive input clusters and\
Expand Down

0 comments on commit 57461d9

Please sign in to comment.