From 21415efe4e4374a1acb676b261d9ddd473df6fd0 Mon Sep 17 00:00:00 2001 From: Rex Xu Date: Mon, 18 Sep 2023 15:14:58 +0800 Subject: [PATCH] Add support for SW primitive statistics counting For pre-GFX11, the primitive statistics counting is performed by HW via the same mechanism of transform feedback. For GFX11+, since transform feedback is done by SW emulation, the primitive statistics counting will follow the same handling. We add a new handler collectPrimitiveStats to deal with it. It is a reduced version of SW transform feedback, only updating HW counters of requested vertex streams. We don't merge it with SW transform feedback because this will make the logic blurry. Indeed, some duplicated codes are the trade-off. For non-GS case, only the counter of stream 0 will be updated and generated primitive count is passed by GE (we don't modify it). For GS case, counters of all active vertex streams will be updated and we must calculate generated primitive count first before doing such update. The calculation is to count valid primitive mask bits in this NGG subgroup and add them together. The implementation of this PR is the foundation of VK_EXT_primitives_generated_query. --- lgc/builder/MiscBuilder.cpp | 12 +- lgc/include/lgc/state/PipelineState.h | 2 +- lgc/patch/NggPrimShader.cpp | 367 ++++++++++++++++++++++---- lgc/patch/NggPrimShader.h | 8 +- 4 files changed, 327 insertions(+), 62 deletions(-) diff --git a/lgc/builder/MiscBuilder.cpp b/lgc/builder/MiscBuilder.cpp index 3aaa0c74df..a84f6a0bc1 100644 --- a/lgc/builder/MiscBuilder.cpp +++ b/lgc/builder/MiscBuilder.cpp @@ -46,8 +46,10 @@ using namespace llvm; Instruction *BuilderImpl::CreateEmitVertex(unsigned streamId) { assert(m_shaderStage == ShaderStageGeometry); - // Mark this vertex stream as active if transform feedback is enabled or this is the rasterization stream. - if (m_pipelineState->enableXfb() || m_pipelineState->getRasterizerState().rasterStream == streamId) + // Mark this vertex stream as active if transform feedback is enabled, or primitive statistics counting is enabled, + // or this is the rasterization stream. + if (m_pipelineState->enableXfb() || m_pipelineState->enablePrimStats() || + m_pipelineState->getRasterizerState().rasterStream == streamId) m_pipelineState->setVertexStreamActive(streamId); // Get GsWaveId @@ -68,8 +70,10 @@ Instruction *BuilderImpl::CreateEmitVertex(unsigned streamId) { Instruction *BuilderImpl::CreateEndPrimitive(unsigned streamId) { assert(m_shaderStage == ShaderStageGeometry); - // Mark this vertex stream as active if transform feedback is enabled or this is the rasterization stream. - if (m_pipelineState->enableXfb() || m_pipelineState->getRasterizerState().rasterStream == streamId) + // Mark this vertex stream as active if transform feedback is enabled, or primitive statistics counting is enabled, + // or this is the rasterization stream. + if (m_pipelineState->enableXfb() || m_pipelineState->enablePrimStats() || + m_pipelineState->getRasterizerState().rasterStream == streamId) m_pipelineState->setVertexStreamActive(streamId); // Get GsWaveId diff --git a/lgc/include/lgc/state/PipelineState.h b/lgc/include/lgc/state/PipelineState.h index 7b3ab0a3f9..6b521c2dcf 100644 --- a/lgc/include/lgc/state/PipelineState.h +++ b/lgc/include/lgc/state/PipelineState.h @@ -396,7 +396,7 @@ class PipelineState final : public Pipeline { // Check if transform feedback is active bool enableXfb() const { return m_xfbStateMetadata.enableXfb; } - // Check if we need count primitives if XFB is disabled + // Check if we need primitive statistics counting bool enablePrimStats() const { return m_xfbStateMetadata.enablePrimStats; } // Get transform feedback strides diff --git a/lgc/patch/NggPrimShader.cpp b/lgc/patch/NggPrimShader.cpp index 285a559649..cd864da182 100644 --- a/lgc/patch/NggPrimShader.cpp +++ b/lgc/patch/NggPrimShader.cpp @@ -242,7 +242,7 @@ PrimShaderLdsUsageInfo NggPrimShader::layoutPrimShaderLds(PipelineState *pipelin ldsUsageInfo.gsExtraLdsSize += ldsRegionSize; // Primitive counts - if (pipelineState->enableSwXfb()) { + if (pipelineState->enableSwXfb() || pipelineState->enablePrimStats()) { ldsRegionSize = (Gfx9::NggMaxWavesPerSubgroup + 1) * MaxGsStreams; // 1 dword per wave and 1 dword per subgroup, 4 GS streams if (ldsLayout) { @@ -265,9 +265,10 @@ PrimShaderLdsUsageInfo NggPrimShader::layoutPrimShaderLds(PipelineState *pipelin } // Vertex counts - if (pipelineState->enableSwXfb()) { + if (pipelineState->enableSwXfb() || pipelineState->enablePrimStats()) { if (ldsLayout) { - // NOTE: If SW emulated stream-out is enabled, this region is overlapped with PrimitiveCounts + // NOTE: If SW emulated stream-out or primitive statistics counting is enabled, this region is overlapped with + // PrimitiveCounts. (*ldsLayout)[PrimShaderLdsRegion::VertexCounts] = (*ldsLayout)[PrimShaderLdsRegion::PrimitiveCounts]; printLdsRegionInfo("Vertex Counts", (*ldsLayout)[PrimShaderLdsRegion::VertexCounts].first, (*ldsLayout)[PrimShaderLdsRegion::VertexCounts].second); @@ -287,7 +288,7 @@ PrimShaderLdsUsageInfo NggPrimShader::layoutPrimShaderLds(PipelineState *pipelin if (pipelineState->getNggControl()->compactVertex) { if (pipelineState->enableSwXfb()) { if (ldsLayout) { - // NOTE: If SW emulated stream-out is enabled, this region is overlapped with PrimitiveIndexMap + // NOTE: If SW emulated stream-out is enabled, this region is overlapped with PrimitiveIndexMap. (*ldsLayout)[PrimShaderLdsRegion::VertexIndexMap] = (*ldsLayout)[PrimShaderLdsRegion::PrimitiveIndexMap]; printLdsRegionInfo("Vertex Index Map (To Uncompacted)", (*ldsLayout)[PrimShaderLdsRegion::VertexIndexMap].first, @@ -513,9 +514,6 @@ Function *NggPrimShader::generate(Function *esMain, Function *gsMain, Function * // ES and GS could not be null at the same time assert((!esMain && !gsMain) == false); - // TODO: support counting generated primitives in software emulated stream-out - assert(!m_pipelineState->enablePrimStats()); - // Assign names to ES, GS and copy shader main functions Module *module = nullptr; if (esMain) { @@ -860,16 +858,18 @@ void NggPrimShader::buildPrimShaderCbLayoutLookupTable() { } // ===================================================================================================================== -// Calculate the dword offset of each item in the stream-out control buffer +// Calculate the dword offset of each item in the stream-out control buffer. void NggPrimShader::calcStreamOutControlCbOffsets() { - assert(m_pipelineState->enableSwXfb()); + assert(m_pipelineState->enableSwXfb() || m_pipelineState->enablePrimStats()); m_streamOutControlCbOffsets = {}; - for (unsigned i = 0; i < MaxTransformFeedbackBuffers; ++i) { - m_streamOutControlCbOffsets.bufOffsets[i] = (offsetof(Util::Abi::StreamOutControlCb, bufOffsets[0]) + - sizeof(Util::Abi::StreamOutControlCb::bufOffsets[0]) * i) / - 4; + if (m_pipelineState->enableSwXfb()) { + for (unsigned i = 0; i < MaxTransformFeedbackBuffers; ++i) { + m_streamOutControlCbOffsets.bufOffsets[i] = (offsetof(Util::Abi::StreamOutControlCb, bufOffsets[0]) + + sizeof(Util::Abi::StreamOutControlCb::bufOffsets[0]) * i) / + 4; + } } } @@ -936,6 +936,9 @@ void NggPrimShader::buildPassthroughPrimShader(Function *primShader) { // if (Enable SW XFB) // Process SW XFB (Run ES) // else { + // if (Enable primitive statistics counting) + // Collect primitive statistics + // // if (threadIdInSubgroup < vertCountInSubgroup) // Run ES (export vertex) // } @@ -964,7 +967,7 @@ void NggPrimShader::buildPassthroughPrimShader(Function *primShader) { // Record attribute ring base ([14:0]) m_nggInputs.attribRingBase = createUBfe(attribRingBase, 0, 15); - if (m_pipelineState->enableSwXfb()) + if (m_pipelineState->enableSwXfb() || m_pipelineState->enablePrimStats()) loadStreamOutBufferInfo(userData); } @@ -1035,8 +1038,11 @@ void NggPrimShader::buildPassthroughPrimShader(Function *primShader) { { m_builder.SetInsertPoint(endExportPrimitiveBlock); + // Process SW XFB or primitive statistics counting if (m_pipelineState->enableSwXfb()) processSwXfb(args); + else if (m_pipelineState->enablePrimStats()) + collectPrimitiveStats(); auto validVertex = m_builder.CreateICmpULT(m_nggInputs.threadIdInSubgroup, m_nggInputs.vertCountInSubgroup); m_builder.CreateCondBr(validVertex, exportVertexBlock, endExportVertexBlock); @@ -1159,6 +1165,8 @@ void NggPrimShader::buildPrimShader(Function *primShader) { // // if (Enable SW XFB) // Process SW XFB + // else if (Enable primitive statistics counting) + // Collect primitive statistics // // if (threadIdInWave < vertCountInWave) // Run part ES to fetch vertex cull data @@ -1287,7 +1295,7 @@ void NggPrimShader::buildPrimShader(Function *primShader) { // Record attribute ring base ([14:0]) m_nggInputs.attribRingBase = createUBfe(attribRingBase, 0, 15); - if (m_pipelineState->enableSwXfb()) + if (m_pipelineState->enableSwXfb() || m_pipelineState->enablePrimStats()) loadStreamOutBufferInfo(userData); } @@ -1308,9 +1316,11 @@ void NggPrimShader::buildPrimShader(Function *primShader) { // Distribute primitive ID if needed distributePrimitiveId(primitiveId); - // Process SW XFB + // Process SW XFB or primitive statistics counting if (m_pipelineState->enableSwXfb()) processSwXfb(args); + else if (m_pipelineState->enablePrimStats()) + collectPrimitiveStats(); m_builder.CreateBr(checkFetchVertexCullDataBlock); } @@ -1891,6 +1901,8 @@ void NggPrimShader::buildPrimShaderWithGs(Function *primShader) { // // if (Enable SW XFB) // Process SW XFB + // else if (Enable primitive statistics counting) + // Collect primitive statistics // // if (threadIdInSubgroup < waveCount + 1) // Initialize per-wave and per-subgroup count of output vertices @@ -1981,7 +1993,7 @@ void NggPrimShader::buildPrimShaderWithGs(Function *primShader) { // Record attribute ring base ([14:0]) m_nggInputs.attribRingBase = createUBfe(attribRingBase, 0, 15); - if (m_pipelineState->enableSwXfb()) + if (m_pipelineState->enableSwXfb() || m_pipelineState->enablePrimStats()) loadStreamOutBufferInfo(userData); } @@ -2046,8 +2058,11 @@ void NggPrimShader::buildPrimShaderWithGs(Function *primShader) { { m_builder.SetInsertPoint(endGsBlock); + // Process SW XFB or primitive statistics counting if (m_pipelineState->enableSwXfb()) processSwXfbWithGs(args); + else if (m_pipelineState->enablePrimStats()) + collectPrimitiveStats(); auto validWave = m_builder.CreateICmpULT(m_nggInputs.threadIdInSubgroup, m_builder.getInt32(waveCountInSubgroup + 1)); @@ -2468,7 +2483,8 @@ void NggPrimShader::initWaveThreadInfo(Value *mergedGroupInfo, Value *mergedWave // // @param userData : User data void NggPrimShader::loadStreamOutBufferInfo(Value *userData) { - assert(m_pipelineState->enableSwXfb()); // Must enable SW emulated stream-out + // Must enable SW emulated stream-out or primitive statistics counting + assert(m_pipelineState->enableSwXfb() || m_pipelineState->enablePrimStats()); calcStreamOutControlCbOffsets(); @@ -2489,7 +2505,18 @@ void NggPrimShader::loadStreamOutBufferInfo(Value *userData) { return userDataIndex; }; - // Get stream-out table pointer value and stream-out control buffer pointer value + // Helper to make a pointer from its integer address value and the type + auto makePointer = [&](Value *ptrValue, Type *ptrTy) { + Value *pc = m_builder.CreateIntrinsic(Intrinsic::amdgcn_s_getpc, {}, {}); + pc = m_builder.CreateBitCast(pc, FixedVectorType::get(m_builder.getInt32Ty(), 2)); + + Value *ptr = m_builder.CreateInsertElement(pc, ptrValue, static_cast(0)); + ptr = m_builder.CreateBitCast(ptr, m_builder.getInt64Ty()); + ptr = m_builder.CreateIntToPtr(ptr, ptrTy); + + return ptr; + }; + const auto gsOrEsMain = m_hasGs ? m_gsHandlers.main : m_esHandlers.main; StreamOutData streamOutData = {}; if (m_hasGs) @@ -2500,46 +2527,39 @@ void NggPrimShader::loadStreamOutBufferInfo(Value *userData) { streamOutData = m_pipelineState->getShaderInterfaceData(ShaderStageVertex)->entryArgIdxs.vs.streamOutData; assert(userData->getType()->isVectorTy()); - auto streamOutTablePtrValue = - m_builder.CreateExtractElement(userData, getUserDataIndex(gsOrEsMain, streamOutData.tablePtr)); + const auto constBufferPtrTy = PointerType::get(m_builder.getContext(), ADDR_SPACE_CONST); + + // Get stream-out control buffer pointer value auto streamOutControlBufPtrValue = m_builder.CreateExtractElement(userData, getUserDataIndex(gsOrEsMain, streamOutData.controlBufPtr)); + m_streamOutControlBufPtr = makePointer(streamOutControlBufPtrValue, constBufferPtrTy); - // Helper to make a pointer from its integer address value and the type - auto makePointer = [&](Value *ptrValue, Type *ptrTy) { - Value *pc = m_builder.CreateIntrinsic(Intrinsic::amdgcn_s_getpc, {}, {}); - pc = m_builder.CreateBitCast(pc, FixedVectorType::get(m_builder.getInt32Ty(), 2)); - - Value *ptr = m_builder.CreateInsertElement(pc, ptrValue, static_cast(0)); - ptr = m_builder.CreateBitCast(ptr, m_builder.getInt64Ty()); - ptr = m_builder.CreateIntToPtr(ptr, ptrTy); + if (m_pipelineState->enableSwXfb()) { + // Get stream-out table pointer value + auto streamOutTablePtrValue = + m_builder.CreateExtractElement(userData, getUserDataIndex(gsOrEsMain, streamOutData.tablePtr)); + auto streamOutTablePtr = makePointer(streamOutTablePtrValue, constBufferPtrTy); - return ptr; - }; + const auto &xfbStrides = m_pipelineState->getXfbBufferStrides(); + for (unsigned i = 0; i < MaxTransformFeedbackBuffers; ++i) { + bool bufferActive = xfbStrides[i] > 0; + if (!bufferActive) + continue; // Transform feedback buffer inactive - const auto constBufferPtrTy = PointerType::get(m_builder.getContext(), ADDR_SPACE_CONST); - auto streamOutTablePtr = makePointer(streamOutTablePtrValue, constBufferPtrTy); - m_streamOutControlBufPtr = makePointer(streamOutControlBufPtrValue, constBufferPtrTy); + // Get stream-out buffer descriptors and record them + m_streamOutBufDescs[i] = readValueFromCb(FixedVectorType::get(m_builder.getInt32Ty(), 4), streamOutTablePtr, + m_builder.getInt32(4 * i)); // <4 x i32> - const auto &xfbStrides = m_pipelineState->getXfbBufferStrides(); - for (unsigned i = 0; i < MaxTransformFeedbackBuffers; ++i) { - bool bufferActive = xfbStrides[i] > 0; - if (!bufferActive) - continue; // Transform feedback buffer inactive - - // Get stream-out buffer descriptors and record them - m_streamOutBufDescs[i] = readValueFromCb(FixedVectorType::get(m_builder.getInt32Ty(), 4), streamOutTablePtr, - m_builder.getInt32(4 * i)); // <4 x i32> - - // NOTE: PAL decided not to invalidate the SQC and L1 for every stream-out update, mainly because that will hurt - // overall performance worse than just forcing this one buffer to be read via L2. Since PAL would not have wider - // context, PAL believed that they would have to perform that invalidation on every Set/Load unconditionally. - // Thus, we force the load of stream-out control buffer to be volatile to let LLVM backend add GLC and DLC flags. - const bool isVolatile = m_gfxIp.major == 11; - // Get stream-out buffer offsets and record them - m_streamOutBufOffsets[i] = - readValueFromCb(m_builder.getInt32Ty(), m_streamOutControlBufPtr, - m_builder.getInt32(m_streamOutControlCbOffsets.bufOffsets[i]), isVolatile); // i32 + // NOTE: PAL decided not to invalidate the SQC and L1 for every stream-out update, mainly because that will hurt + // overall performance worse than just forcing this one buffer to be read via L2. Since PAL would not have wider + // context, PAL believed that they would have to perform that invalidation on every Set/Load unconditionally. + // Thus, we force the load of stream-out control buffer to be volatile to let LLVM backend add GLC and DLC flags. + const bool isVolatile = m_gfxIp.major == 11; + // Get stream-out buffer offsets and record them + m_streamOutBufOffsets[i] = + readValueFromCb(m_builder.getInt32Ty(), m_streamOutControlBufPtr, + m_builder.getInt32(m_streamOutControlCbOffsets.bufOffsets[i]), isVolatile); // i32 + } } } @@ -6218,7 +6238,7 @@ void NggPrimShader::processSwXfb(ArrayRef args) { // Calculate primsToWrite and dwordsToWrite // Increment GDS_STRMOUT_DWORDS_WRITTEN_X and release the control // Store XFB statistics info to LDS - // Increment GDS_STRMOUT_PRIMS_NEEDED_X and GDS_STRMOUT_PRIMS_WRITTEN_X + // Increment GDS_STRMOUT_PRIMS_NEEDED_0 and GDS_STRMOUT_PRIMS_WRITTEN_0 // } // Barrier // @@ -7514,6 +7534,245 @@ Value *NggPrimShader::fetchXfbOutput(Function *target, ArrayRef args return m_builder.CreateCall(xfbFetcher, xfbFetcherArgs); } +// ===================================================================================================================== +// Collect primitive statistics (primitive statistics counting) and update the values in HW counters. +void NggPrimShader::collectPrimitiveStats() { + // NOTE: For SW emulated stream-out, the processing will update HW counters at the same time unconditionally. We don't + // have to particularly call this function. + assert(!m_pipelineState->enableSwXfb()); + assert(m_pipelineState->enablePrimStats()); // Make sure we do need to count generated primitives + + if (!m_hasGs) { + // GS is not present + + // + // The processing is something like this: + // + // NGG_PRIM_STATS() { + // if (threadIdInSubgroup == 0) + // Increment GDS_STRMOUT_PRIMS_NEEDED_0 and GDS_STRMOUT_PRIMS_WRITTEN_0 + // } + // + BasicBlock *insertBlock = m_builder.GetInsertBlock(); + + BasicBlock *collectPrimitiveStatsBlock = createBlock(insertBlock->getParent(), ".collectPrimitiveStats"); + collectPrimitiveStatsBlock->moveAfter(insertBlock); + BasicBlock *endCollectPrimitiveStatsBlock = createBlock(insertBlock->getParent(), ".endCollectPrimitiveStats"); + endCollectPrimitiveStatsBlock->moveAfter(collectPrimitiveStatsBlock); + + // Insert branching in current block to collect primitive statistics + { + auto firstThreadInSubgroup = m_builder.CreateICmpEQ(m_nggInputs.threadIdInSubgroup, m_builder.getInt32(0)); + m_builder.CreateCondBr(firstThreadInSubgroup, collectPrimitiveStatsBlock, endCollectPrimitiveStatsBlock); + } + + // Construct ".collectPrimitiveStats" block + { + m_builder.SetInsertPoint(collectPrimitiveStatsBlock); + + if (m_gfxIp.major <= 11) { + m_builder.CreateIntrinsic(Intrinsic::amdgcn_ds_add_gs_reg_rtn, m_nggInputs.primCountInSubgroup->getType(), + {m_nggInputs.primCountInSubgroup, // value to add + m_builder.getInt32(GDS_STRMOUT_PRIMS_NEEDED_0 << 2)}); // count index + + m_builder.CreateIntrinsic(Intrinsic::amdgcn_ds_add_gs_reg_rtn, m_builder.getInt32Ty(), + {m_builder.getInt32(0), // value to add + m_builder.getInt32(GDS_STRMOUT_PRIMS_WRITTEN_0 << 2)}); // count index + } else { + llvm_unreachable("Not implemented!"); + } + + m_builder.CreateBr(endCollectPrimitiveStatsBlock); + } + + // Construct ".endCollectPrimitiveStats" block + { m_builder.SetInsertPoint(endCollectPrimitiveStatsBlock); } + + return; + } + + // GS is present + assert(m_hasGs); + + // + // The processing is something like this: + // + // NGG_GS_PRIM_STATS() { + // if (threadIdInSubgroup < MaxGsStreams) + // Initialize output primitive count for each vertex stream + // Barrier + // + // if (threadIdInSubgroup < primCountInSubgroup) + // Check the draw flag of output primitives and compute draw mask + // + // if (threadIdInWave == 0) + // Accumulate output primitive count + // Barrier + // + // if (threadIdInSubgroup == 0) { + // for each vertex stream + // Increment GDS_STRMOUT_PRIMS_NEEDED_X and GDS_STRMOUT_PRIMS_WRITTEN_X + // } + // } + // + BasicBlock *insertBlock = m_builder.GetInsertBlock(); + + BasicBlock *initPrimitiveCountsBlock = createBlock(insertBlock->getParent(), ".initPrimitiveCounts"); + initPrimitiveCountsBlock->moveAfter(insertBlock); + BasicBlock *endInitPrimitiveCountsBlock = createBlock(insertBlock->getParent(), ".endInitPrimitiveCounts"); + endInitPrimitiveCountsBlock->moveAfter(initPrimitiveCountsBlock); + + BasicBlock *checkPrimitiveDrawFlagBlock = createBlock(insertBlock->getParent(), ".checkPrimitiveDrawFlag"); + checkPrimitiveDrawFlagBlock->moveAfter(endInitPrimitiveCountsBlock); + BasicBlock *endCheckPrimitiveDrawFlagBlock = createBlock(insertBlock->getParent(), ".endCheckPrimitiveDrawFlag"); + endCheckPrimitiveDrawFlagBlock->moveAfter(checkPrimitiveDrawFlagBlock); + + BasicBlock *countPrimitivesBlock = createBlock(insertBlock->getParent(), ".countPrimitives"); + countPrimitivesBlock->moveAfter(endCheckPrimitiveDrawFlagBlock); + BasicBlock *endCountPrimitivesBlock = createBlock(insertBlock->getParent(), ".endCountPrimitives"); + endCountPrimitivesBlock->moveAfter(countPrimitivesBlock); + + BasicBlock *collectPrimitiveStatsBlock = createBlock(insertBlock->getParent(), ".collectPrimitiveStats"); + collectPrimitiveStatsBlock->moveAfter(endCountPrimitivesBlock); + BasicBlock *endCollectPrimitiveStatsBlock = createBlock(insertBlock->getParent(), ".endCollectPrimitiveStats"); + endCollectPrimitiveStatsBlock->moveAfter(collectPrimitiveStatsBlock); + + // Insert branching in current block to collect primitive statistics + { + auto validStream = m_builder.CreateICmpULT(m_nggInputs.threadIdInSubgroup, m_builder.getInt32(MaxGsStreams)); + m_builder.CreateCondBr(validStream, initPrimitiveCountsBlock, endInitPrimitiveCountsBlock); + } + + // Construct ".initPrimitiveCounts" block + { + m_builder.SetInsertPoint(initPrimitiveCountsBlock); + + writePerThreadDataToLds(m_builder.getInt32(0), m_nggInputs.threadIdInSubgroup, + PrimShaderLdsRegion::PrimitiveCounts); + + m_builder.CreateBr(endInitPrimitiveCountsBlock); + } + + // Construct ".endInitPrimitiveCounts" block + { + m_builder.SetInsertPoint(endInitPrimitiveCountsBlock); + + createFenceAndBarrier(); + + auto validPrimitive = m_builder.CreateICmpULT(m_nggInputs.threadIdInSubgroup, m_nggInputs.primCountInSubgroup); + m_builder.CreateCondBr(validPrimitive, checkPrimitiveDrawFlagBlock, endCheckPrimitiveDrawFlagBlock); + } + + // Construct ".checkPrimitiveDrawFlag" block + Value *drawFlag[MaxGsStreams] = {}; + { + m_builder.SetInsertPoint(checkPrimitiveDrawFlagBlock); + + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (m_pipelineState->isVertexStreamActive(i)) { + // drawFlag = primData[N] != NullPrim + auto primData = + readPerThreadDataFromLds(m_builder.getInt32Ty(), m_nggInputs.threadIdInSubgroup, + PrimShaderLdsRegion::PrimitiveData, Gfx9::NggMaxThreadsPerSubgroup * i); + drawFlag[i] = m_builder.CreateICmpNE(primData, m_builder.getInt32(NullPrim)); + } + } + + m_builder.CreateBr(endCheckPrimitiveDrawFlagBlock); + } + + // Construct ".endCheckPrimitiveDrawFlag" block + Value *drawMask[MaxGsStreams] = {}; + Value *primCountInWave[MaxGsStreams] = {}; + { + m_builder.SetInsertPoint(endCheckPrimitiveDrawFlagBlock); + + // Update draw flags + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (m_pipelineState->isVertexStreamActive(i)) { + drawFlag[i] = createPhi( + {{drawFlag[i], checkPrimitiveDrawFlagBlock}, {m_builder.getFalse(), endInitPrimitiveCountsBlock}}); + } + } + + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (m_pipelineState->isVertexStreamActive(i)) { + drawMask[i] = ballot(drawFlag[i]); + + primCountInWave[i] = m_builder.CreateUnaryIntrinsic(Intrinsic::ctpop, drawMask[i]); + primCountInWave[i] = m_builder.CreateTrunc(primCountInWave[i], m_builder.getInt32Ty()); + } + } + + auto firstThreadInWave = m_builder.CreateICmpEQ(m_nggInputs.threadIdInWave, m_builder.getInt32(0)); + m_builder.CreateCondBr(firstThreadInWave, countPrimitivesBlock, endCountPrimitivesBlock); + } + + // Construct ".countPrimitives" block + { + m_builder.SetInsertPoint(countPrimitivesBlock); + + unsigned regionStart = getLdsRegionStart(PrimShaderLdsRegion::PrimitiveCounts); + + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (m_pipelineState->isVertexStreamActive(i)) { + atomicAdd(primCountInWave[i], m_builder.getInt32(regionStart + i)); + } + } + + m_builder.CreateBr(endCountPrimitivesBlock); + } + + // Construct ".endCountPrimitives" block + Value *primCountInSubgroup[MaxGsStreams] = {}; + { + m_builder.SetInsertPoint(endCountPrimitivesBlock); + + createFenceAndBarrier(); + + auto primCountInStreams = readPerThreadDataFromLds(m_builder.getInt32Ty(), m_nggInputs.threadIdInWave, + PrimShaderLdsRegion::PrimitiveCounts); + + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (!m_pipelineState->isVertexStreamActive(i)) + continue; + + primCountInSubgroup[i] = m_builder.CreateIntrinsic(m_builder.getInt32Ty(), Intrinsic::amdgcn_readlane, + {primCountInStreams, m_builder.getInt32(i)}); + } + + auto firstThreadInSubgroup = m_builder.CreateICmpEQ(m_nggInputs.threadIdInSubgroup, m_builder.getInt32(0)); + m_builder.CreateCondBr(firstThreadInSubgroup, collectPrimitiveStatsBlock, endCollectPrimitiveStatsBlock); + } + + // Construct ".collectPrimitiveStats" block + { + m_builder.SetInsertPoint(collectPrimitiveStatsBlock); + + for (unsigned i = 0; i < MaxGsStreams; ++i) { + if (!m_pipelineState->isVertexStreamActive(i)) + continue; + + if (m_gfxIp.major <= 11) { + m_builder.CreateIntrinsic(Intrinsic::amdgcn_ds_add_gs_reg_rtn, primCountInSubgroup[i]->getType(), + {primCountInSubgroup[i], // value to add + m_builder.getInt32((GDS_STRMOUT_PRIMS_NEEDED_0 + 2 * i) << 2)}); // count index + + m_builder.CreateIntrinsic(Intrinsic::amdgcn_ds_add_gs_reg_rtn, m_builder.getInt32Ty(), + {m_builder.getInt32(0), // value to add + m_builder.getInt32((GDS_STRMOUT_PRIMS_WRITTEN_0 + 2 * i) << 2)}); // count index + } else { + llvm_unreachable("Not implemented!"); + } + } + + m_builder.CreateBr(endCollectPrimitiveStatsBlock); + } + + // Construct ".endCollectPrimitiveStats" block + { m_builder.SetInsertPoint(endCollectPrimitiveStatsBlock); } +} + // ===================================================================================================================== // Reads transform feedback output from LDS // diff --git a/lgc/patch/NggPrimShader.h b/lgc/patch/NggPrimShader.h index 0fbba2ed00..a69aa1ef12 100644 --- a/lgc/patch/NggPrimShader.h +++ b/lgc/patch/NggPrimShader.h @@ -322,6 +322,8 @@ class NggPrimShader { return m_ldsLayout[region].first; } + void collectPrimitiveStats(); + llvm::Value *readValueFromLds(llvm::Type *readTy, llvm::Value *ldsOffset, bool useDs128 = false); void writeValueToLds(llvm::Value *writeValue, llvm::Value *ldsOffset, bool useDs128 = false); void atomicAdd(llvm::Value *valueToAdd, llvm::Value *ldsOffset); @@ -394,9 +396,9 @@ class NggPrimShader { bool m_hasTes = false; // Whether the pipeline has tessellation evaluation shader bool m_hasGs = false; // Whether the pipeline has geometry shader - llvm::Value *m_streamOutControlBufPtr; // Stream-out control buffer pointer - llvm::Value *m_streamOutBufDescs[MaxTransformFeedbackBuffers]; // Stream-out buffer descriptors - llvm::Value *m_streamOutBufOffsets[MaxTransformFeedbackBuffers]; // Stream-out buffer offsets + llvm::Value *m_streamOutControlBufPtr = nullptr; // Stream-out control buffer pointer + llvm::Value *m_streamOutBufDescs[MaxTransformFeedbackBuffers] = {}; // Stream-out buffer descriptors + llvm::Value *m_streamOutBufOffsets[MaxTransformFeedbackBuffers] = {}; // Stream-out buffer offsets bool m_constPositionZ = false; // Whether the Z channel of vertex position data is constant