-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix a few bugs in the MPMC Queue and improve the Scheduler
- Loading branch information
devsh
committed
Oct 18, 2024
1 parent
f8f6571
commit e691879
Showing
6 changed files
with
221 additions
and
266 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#ifndef _NBL_HLSL_MPMC_QUEUE_HLSL_ | ||
#define _NBL_HLSL_MPMC_QUEUE_HLSL_ | ||
|
||
#include "nbl/builtin/hlsl/glsl_compat/core.hlsl" | ||
#include "nbl/builtin/hlsl/type_traits.hlsl" | ||
#include "nbl/builtin/hlsl/bda/__ptr.hlsl" | ||
|
||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
|
||
// ONLY ONE INVOCATION IN A WORKGROUP USES THESE | ||
// NO OVERFLOW PROTECTION, YOU MUSTN'T OVERFLOW! | ||
template<typename T> | ||
// ONLY ONE INVOCATION IN A WORKGROUP USES THESE
// NO OVERFLOW PROTECTION, YOU MUSTN'T OVERFLOW!
template<typename T>
struct MPMCQueue
{
	// Offsets (in `uint32_t` units) of the three counters inside the `counters` buffer:
	// - Reserved: producers atomically reserve write slots here
	// - Comitted: high-water mark of slots whose payload has been fully written
	//   (NOTE(review): "Comitted" spelling kept as-is — renaming would break users of this public constant)
	// - Popped: consumers' read cursor
	const static uint64_t ReservedOffset = 0;
	const static uint64_t ComittedOffset = 1;
	const static uint64_t PoppedOffset = 2;

	// we don't actually need to use 64-bit offsets/counters for the storage indexing,
	// capacity is a power of two so 32-bit counters wrap around the ring cleanly
	// @param ix monotonically increasing slot index, masked down to the ring capacity
	// @return device address of the slot's payload
	uint64_t getStorage(const uint32_t ix)
	{
		return pStorage+(ix&((0x1u<<capacityLog2)-1))*sizeof(T);
	}

	// Producer side: reserve a slot, write the payload, then publish it.
	void push(const in T val)
	{
		// reserve output index
		bda::__ref<uint32_t,4> reserved = (counters+ReservedOffset).deref();
		const uint32_t dstIx = spirv::atomicIAdd(reserved.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsAcquireMask,1);
		// write
		vk::RawBufferStore(getStorage(dstIx),val);
		// say its ready for consumption
		// BUGFIX: publish `dstIx+1` (the count of items up to and including ours), not the
		// constant `1` — with the constant, consumers could never see more than one item.
		// NOTE(review): with multiple concurrent producers a umax can advertise slots whose
		// writes haven't landed yet; acceptable only under the single-pusher-per-workgroup
		// contract stated above — confirm against callers.
		bda::__ref<uint32_t,4> committed = (counters+ComittedOffset).deref();
		spirv::atomicUMax(committed.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsReleaseMask,dstIx+1);
	}

	// everything here must be done by one invocation between two barriers, all invocations must call this method
	// `electedInvocation` must be true for only one invocation and be such that `endOffsetInPopped` has the highest value amongst them
	// @param accessor shared-memory broadcast accessor (slots 0 and 1 are clobbered)
	// @param active whether this invocation actually wants an item
	// @param val output payload, only valid when the return value is true
	// @param endOffsetInPopped this invocation's 1-based offset into the popped span
	// @param beginHint caller's best guess of the current popped counter (seeds the CAS loop)
	// @return true if this invocation received an item
	template<typename BroadcastAccessor>
	bool pop(BroadcastAccessor accessor, const bool active, out T val, const uint16_t endOffsetInPopped, const bool electedInvocation, const uint32_t beginHint)
	{
		if (electedInvocation)
		{
			// BUGFIX: seed `begin` from `beginHint` — previously the hint was ignored and
			// `begin` was read uninitialized on the first loop iteration
			uint32_t begin = beginHint;
			uint32_t end;
			// strictly speaking CAS loops have FP because one WG will perform the comp-swap and make progress
			uint32_t expected;
			bda::__ref<uint32_t,4> committed = (counters+ComittedOffset).deref();
			bda::__ref<uint32_t,4> popped = (counters+PoppedOffset).deref();
			do
			{
				// TODO: replace `atomicIAdd(p,0)` with `atomicLoad(p)`
				// BUGFIX: assign to the outer `end` — the old `uint32_t end = ...` shadowed it
				// and left the value broadcast below uninitialized
				end = spirv::atomicIAdd(committed.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsAcquireReleaseMask,0u);
				// don't take more items than the workgroup asked for
				end = min(end,begin+endOffsetInPopped);
				expected = begin;
				begin = spirv::atomicCompareExchange(
					popped.__get_spv_ptr(),
					spv::ScopeWorkgroup,
					spv::MemorySemanticsAcquireReleaseMask, // equal needs total ordering
					spv::MemorySemanticsMaskNone, // unequal no memory ordering
					end,
					expected
				);
			} while (begin!=expected);
			accessor.set(0,begin);
			accessor.set(1,end);
		}
		// broadcast the span to everyone
		nbl::hlsl::glsl::barrier();
		bool took = false;
		if (active)
		{
			uint32_t begin;
			uint32_t end;
			accessor.get(0,begin);
			accessor.get(1,end);
			// each invocation claims the item at its own offset within [begin,end)
			begin += endOffsetInPopped;
			if (begin<=end)
			{
				val = vk::RawBufferLoad<T>(getStorage(begin-1));
				took = true;
			}
		}
		return took;
	}
	// Convenience overload: reads the current popped counter itself to use as the hint.
	template<typename BroadcastAccessor>
	bool pop(BroadcastAccessor accessor, const bool active, out T val, const uint16_t endOffsetInPopped, const bool electedInvocation)
	{
		// TODO: replace `atomicIAdd(p,0)` with `atomicLoad(p)`
		const uint32_t beginHint = spirv::atomicIAdd((counters+PoppedOffset).deref().__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsMaskNone,0u);
		return pop(accessor,active,val,endOffsetInPopped,electedInvocation,beginHint);
	}

	// device pointer to the 3 counters (Reserved, Comitted, Popped)
	bda::__ptr<uint32_t> counters;
	// device address of the ring-buffer payload storage
	uint64_t pStorage;
	// log2 of the ring capacity; capacity must be a power of two
	uint16_t capacityLog2;
};
|
||
} | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
#ifndef _NBL_HLSL_MPMC_HLSL_ | ||
#define _NBL_HLSL_MPMC_HLSL_ | ||
|
||
#include "workgroup/stack.hlsl" | ||
#include "mpmc_queue.hlsl" | ||
|
||
#include "nbl/builtin/hlsl/workgroup/scratch_size.hlsl" | ||
#include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl" | ||
#include "nbl/builtin/hlsl/bda/__ptr.hlsl" | ||
|
||
namespace nbl | ||
{ | ||
namespace hlsl | ||
{ | ||
namespace schedulers | ||
{ | ||
|
||
// TODO: improve and use a Global Pool Allocator and stop moving whole payloads around in VRAM | ||
template<typename Task, uint32_t WorkGroupSize, typename SharedAccessor, typename GlobalQueue, class device_capabilities=void> | ||
// TODO: improve and use a Global Pool Allocator and stop moving whole payloads around in VRAM
template<typename Task, uint32_t WorkGroupSize, typename SharedAccessor, typename GlobalQueue, class device_capabilities=void>
struct MPMC
{
	// TODO: static assert that the signature of the `Task::operator()` is `void()`
	// Shared-memory layout: [0,BallotDWORDS) ballot scratch, then scan scratch, then one DWORD for the pop-count reduction.
	static const uint16_t BallotDWORDS = workgroup::scratch_size_ballot<WorkGroupSize>::value;
	static const uint16_t ScanDWORDS = workgroup::scratch_size_arithmetic<WorkGroupSize>::value;
	static const uint16_t PopCountOffset = BallotDWORDS+ScanDWORDS;

	// Enqueue a task: keep it in `next` if this invocation is idle, otherwise spill
	// to the shared-memory stack, and to the global queue if the stack is full.
	void push(const in Task payload)
	{
		// already stole some work, need to spill
		if (nextValid)
		{
			// if the shared memory stack will overflow
			if (!sStack.push(payload))
			{
				// spill to a global queue
				gQueue.push(payload);
			}
		}
		else
			next = payload;
	}

	// returns this invocation's inclusive prefix over "wants to pop" ballots;
	// writes the whole-workgroup total through `reduction`
	uint16_t popCountInclusive_impl(out uint16_t reduction)
	{
		// ballot how many items want to be taken
		workgroup::ballot(!nextValid,sStack.accessor);
		// clear the count
		sStack.accessor.set(PopCountOffset,0);
		glsl::barrier();
		// prefix sum over the active requests
		using ArithmeticAccessor = accessor_adaptors::Offset<SharedAccessor,uint32_t,integral_constant<uint32_t,BallotDWORDS> >;
		ArithmeticAccessor arithmeticAccessor;
		arithmeticAccessor.accessor = sStack.accessor;
		const uint16_t retval = workgroup::ballotInclusiveBitCount<WorkGroupSize,SharedAccessor,ArithmeticAccessor,device_capabilities>(sStack.accessor,arithmeticAccessor);
		sStack.accessor = arithmeticAccessor.accessor;
		// get the reduction
		if (glsl::gl_LocalInvocationIndex()==(WorkGroupSize-1))
			sStack.accessor.set(PopCountOffset,retval);
		glsl::barrier();
		sStack.accessor.get(PopCountOffset,reduction);
		return retval;
	}

	// Main scheduler loop: run the held task, then refill from the shared stack and,
	// if too many invocations are still idle, from the global queue.
	void operator()()
	{
		const bool lastInvocationInGroup = glsl::gl_LocalInvocationIndex()==(WorkGroupSize-1);
		// need to quit when we don't get any work, otherwise we'd spin expecting forward progress guarantees
		for (uint16_t popCount=0xffffu; popCount; )
		{
			if (nextValid) // this invocation has some work to do
			{
				// ensure by-value semantics, the task may push work itself
				Task tmp = next;
				nextValid = false;
				tmp();
			}
			// everyone sync up here so we can count how many invocations won't have jobs
			glsl::barrier();
			uint16_t popCountInclusive = popCountInclusive_impl(popCount);
			// now try and pop work from our shared memory stack
			if (popCount > sharedAcceptableIdleCount)
			{
				// look at the way the `||` is expressed, its specifically that way to avoid short circuiting!
				nextValid = sStack.pop(!nextValid,next,popCountInclusive,lastInvocationInGroup) || nextValid;
				// now if there's still a problem, grab some tasks from the global ring-buffer
				popCountInclusive = popCountInclusive_impl(popCount);
				if (popCount > globalAcceptableIdleCount)
				{
					// reuse the ballot smem for broadcasts, nobody needs the ballot state now
					// BUGFIX: record the pop's success — the return value used to be discarded,
					// so a task popped from the global queue was silently lost
					// (same non-short-circuiting `||` order as the stack pop above)
					nextValid = gQueue.pop(sStack.accessor,!nextValid,next,popCountInclusive,lastInvocationInGroup,0) || nextValid;
				}
			}
		}
	}

	// NOTE(review): the `GlobalQueue` template parameter is never used; presumably
	// `gQueue` was meant to be of type `GlobalQueue` — confirm with the author
	MPMCQueue<Task> gQueue;
	workgroup::Stack<Task,SharedAccessor,PopCountOffset+1> sStack;
	// the task this invocation will run next, valid only when `nextValid` is true
	Task next;
	// popping work from the stack and queue might be expensive, expensive enough to not justify doing all the legwork to just pull a few items of work
	uint16_t sharedAcceptableIdleCount;
	uint16_t globalAcceptableIdleCount;
	bool nextValid;
};
|
||
} | ||
} | ||
|
||
} | ||
} | ||
} | ||
#endif |
Oops, something went wrong.