fix a few bugs in the MPMC Queue and improve the Scheduler
devsh committed Oct 18, 2024
1 parent f8f6571 commit e691879
Showing 6 changed files with 221 additions and 266 deletions.
104 changes: 104 additions & 0 deletions 27_MPMCScheduler/app_resources/mpmc_queue.hlsl
@@ -0,0 +1,104 @@
#ifndef _NBL_HLSL_MPMC_QUEUE_HLSL_
#define _NBL_HLSL_MPMC_QUEUE_HLSL_

#include "nbl/builtin/hlsl/glsl_compat/core.hlsl"
#include "nbl/builtin/hlsl/type_traits.hlsl"
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"

namespace nbl
{
namespace hlsl
{

// ONLY ONE INVOCATION IN A WORKGROUP USES THESE
// NO OVERFLOW PROTECTION, YOU MUSTN'T OVERFLOW!
template<typename T>
struct MPMCQueue
{
const static uint64_t ReservedOffset = 0;
const static uint64_t CommittedOffset = 1;
const static uint64_t PoppedOffset = 2;

// we don't actually need to use 64-bit offsets/counters for the queue indices, 32-bit counters are enough; only the storage address has to be 64-bit
uint64_t getStorage(const uint32_t ix)
{
return pStorage+(ix&((0x1u<<capacityLog2)-1))*sizeof(T);
}

void push(const in T val)
{
// reserve output index
bda::__ref<uint32_t,4> reserved = (counters+ReservedOffset).deref();
const uint32_t dstIx = spirv::atomicIAdd(reserved.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsAcquireMask,1);
// write
vk::RawBufferStore(getStorage(dstIx),val);
// say it's ready for consumption, the committed counter must cover our freshly written slot so raise it to at least `dstIx+1`
bda::__ref<uint32_t,4> committed = (counters+CommittedOffset).deref();
spirv::atomicUMax(committed.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsReleaseMask,dstIx+1);
}

// everything in here must be done by one invocation between two barriers, but all invocations must call this method
// `electedInvocation` must be true for only one invocation, and that invocation must have the highest `endOffsetInPopped` amongst the callers
template<typename BroadcastAccessor>
bool pop(BroadcastAccessor accessor, const bool active, out T val, const uint16_t endOffsetInPopped, const bool electedInvocation, const uint32_t beginHint)
{
if (electedInvocation)
{
// start from the hint so the first compare-exchange has a sensible comparand
uint32_t begin = beginHint;
uint32_t end;
// strictly speaking CAS loops have Forward Progress because one WG will always perform the comp-swap and make progress
uint32_t expected;
bda::__ref<uint32_t,4> committed = (counters+CommittedOffset).deref();
bda::__ref<uint32_t,4> popped = (counters+PoppedOffset).deref();
do
{
// TODO: replace `atomicIAdd(p,0)` with `atomicLoad(p)`
// assign to the outer `end` (no redeclaration) so the value survives for the broadcast below
end = spirv::atomicIAdd(committed.__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsAcquireReleaseMask,0u);
end = min(end,begin+endOffsetInPopped);
expected = begin;
begin = spirv::atomicCompareExchange(
popped.__get_spv_ptr(),
spv::ScopeWorkgroup,
spv::MemorySemanticsAcquireReleaseMask, // equal needs total ordering
spv::MemorySemanticsMaskNone, // unequal no memory ordering
end,
expected
);
} while (begin!=expected);
accessor.set(0,begin);
accessor.set(1,end);
}
// broadcast the span to everyone
nbl::hlsl::glsl::barrier();
bool took = false;
if (active)
{
uint32_t begin;
uint32_t end;
accessor.get(0,begin);
accessor.get(1,end);
begin += endOffsetInPopped;
if (begin<=end)
{
val = vk::RawBufferLoad<T>(getStorage(begin-1));
took = true;
}
}
return took;
}
template<typename BroadcastAccessor>
bool pop(BroadcastAccessor accessor, const bool active, out T val, const uint16_t endOffsetInPopped, const bool electedInvocation)
{
// TODO: replace `atomicIAdd(p,0)` with `atomicLoad(p)`
const uint32_t beginHint = spirv::atomicIAdd((counters+PoppedOffset).deref().__get_spv_ptr(),spv::ScopeWorkgroup,spv::MemorySemanticsMaskNone,0u);
return pop(accessor,active,val,endOffsetInPopped,electedInvocation,beginHint);
}

bda::__ptr<uint32_t> counters;
uint64_t pStorage;
uint16_t capacityLog2;
};

}
}
#endif
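
For orientation, here is a minimal usage sketch of the queue from a compute shader. Only the `MPMCQueue` members and the calling contract come from the file above; the workgroup size, the push-constant layout, the `SpanAccessor` broadcast helper and the zero-initialized host-side counter buffer are all assumptions for illustration.

#include "mpmc_queue.hlsl"

static const uint32_t WorkGroupSize = 256;

// hypothetical push constants handing the shader the device addresses of the queue
struct PushConstants
{
    uint64_t countersAddress; // 3 zero-initialized uint32_t counters (Reserved, Committed, Popped)
    uint64_t storageAddress;  // (1<<capacityLog2)*sizeof(T) bytes of backing storage
    uint32_t capacityLog2;
};
[[vk::push_constant]] PushConstants pc;

// tiny groupshared accessor so `pop` can broadcast its [begin,end) span
groupshared uint32_t sSpan[2];
struct SpanAccessor
{
    void set(const uint32_t ix, const uint32_t val) { sSpan[ix] = val; }
    void get(const uint32_t ix, out uint32_t val) { val = sSpan[ix]; }
};

[numthreads(WorkGroupSize,1,1)]
void main()
{
    nbl::hlsl::MPMCQueue<uint32_t> queue;
    queue.counters = nbl::hlsl::bda::__ptr<uint32_t>::create(pc.countersAddress);
    queue.pStorage = pc.storageAddress;
    queue.capacityLog2 = uint16_t(pc.capacityLog2);

    // a single producer invocation pushes one item
    if (nbl::hlsl::glsl::gl_LocalInvocationIndex() == 0)
        queue.push(42u);
    nbl::hlsl::glsl::barrier();

    // every invocation must call `pop`; the elected one has the highest inclusive offset
    SpanAccessor accessor;
    uint32_t item;
    const uint32_t selfIx = nbl::hlsl::glsl::gl_LocalInvocationIndex();
    const bool took = queue.pop(accessor, true, item, uint16_t(selfIx + 1), selfIx == WorkGroupSize - 1);
    // ... only invocations where `took` is true received an `item` ...
}

Note the explicit `barrier()` before the call: per the comment in the header, the elected invocation's section must sit between two barriers, and `pop` itself only issues the trailing one.
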
112 changes: 112 additions & 0 deletions 27_MPMCScheduler/app_resources/schedulers/mpmc.hlsl
@@ -0,0 +1,112 @@
#ifndef _NBL_HLSL_MPMC_HLSL_
#define _NBL_HLSL_MPMC_HLSL_

#include "workgroup/stack.hlsl"
#include "mpmc_queue.hlsl"

#include "nbl/builtin/hlsl/workgroup/scratch_size.hlsl"
#include "nbl/builtin/hlsl/workgroup/arithmetic.hlsl"
#include "nbl/builtin/hlsl/bda/__ptr.hlsl"

namespace nbl
{
namespace hlsl
{
namespace schedulers
{

// TODO: improve and use a Global Pool Allocator and stop moving whole payloads around in VRAM
template<typename Task, uint32_t WorkGroupSize, typename SharedAccessor, typename GlobalQueue, class device_capabilities=void>
struct MPMC
{
// TODO: static assert that the signature of the `Task::operator()` is `void()`
static const uint16_t BallotDWORDS = workgroup::scratch_size_ballot<WorkGroupSize>::value;
static const uint16_t ScanDWORDS = workgroup::scratch_size_arithmetic<WorkGroupSize>::value;
static const uint16_t PopCountOffset = BallotDWORDS+ScanDWORDS;

void push(const in Task payload)
{
// already stole some work, need to spill
if (nextValid)
{
// if the shared memory stack will overflow
if (!sStack.push(payload))
{
// spill to a global queue
gQueue.push(payload);
}
}
else
{
next = payload;
// mark the slot as occupied so the task actually gets run and later pushes spill correctly
nextValid = true;
}
}

// returns the inclusive prefix count over the invocations that want to pop, and writes the workgroup-wide total to `reduction`
uint16_t popCountInclusive_impl(out uint16_t reduction)
{
// ballot how many items want to be taken
workgroup::ballot(!nextValid,sStack.accessor);
// clear the count
sStack.accessor.set(PopCountOffset,0);
glsl::barrier();
// prefix sum over the active requests
using ArithmeticAccessor = accessor_adaptors::Offset<SharedAccessor,uint32_t,integral_constant<uint32_t,BallotDWORDS> >;
ArithmeticAccessor arithmeticAccessor;
arithmeticAccessor.accessor = sStack.accessor;
const uint16_t retval = workgroup::ballotInclusiveBitCount<WorkGroupSize,SharedAccessor,ArithmeticAccessor,device_capabilities>(sStack.accessor,arithmeticAccessor);
sStack.accessor = arithmeticAccessor.accessor;
// get the reduction
if (glsl::gl_LocalInvocationIndex()==(WorkGroupSize-1))
sStack.accessor.set(PopCountOffset,retval);
glsl::barrier();
sStack.accessor.get(PopCountOffset,reduction);
return retval;
}

void operator()()
{
const bool lastInvocationInGroup = glsl::gl_LocalInvocationIndex()==(WorkGroupSize-1);
// need to quit when we don't get any work, otherwise we'd spin expecting forward progress guarantees
for (uint16_t popCount=0xffffu; popCount; )
{
if (nextValid) // this invocation has some work to do
{
// ensure by-value semantics, the task may push work itself
Task tmp = next;
nextValid = false;
tmp();
}
// everyone sync up here so we can count how many invocations won't have jobs
glsl::barrier();
uint16_t popCountInclusive = popCountInclusive_impl(popCount);
// now try to pop work from our shared memory stack
if (popCount > sharedAcceptableIdleCount)
{
// note how the `||` is expressed: it's written that way specifically to avoid short-circuiting!
nextValid = sStack.pop(!nextValid,next,popCountInclusive,lastInvocationInGroup) || nextValid;
// now if there's still a problem, grab some tasks from the global ring-buffer
popCountInclusive = popCountInclusive_impl(popCount);
if (popCount > globalAcceptableIdleCount)
{
// reuse the ballot smem for broadcasts, nobody needs the ballot state now
// same trick as above: the call goes on the left of `||` so every invocation participates and a successful global pop marks `next` valid
nextValid = gQueue.pop(sStack.accessor,!nextValid,next,popCountInclusive,lastInvocationInGroup,0) || nextValid;
}
}
}
}

MPMCQueue<Task> gQueue;
workgroup::Stack<Task,SharedAccessor,PopCountOffset+1> sStack;
Task next;
// popping work from the stack and queue might be expensive, expensive enough to not justify doing all the legwork to just pull a few items of work
uint16_t sharedAcceptableIdleCount;
uint16_t globalAcceptableIdleCount;
bool nextValid;
};

}
}
}
#endif
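
And a rough sketch of how a workgroup might wire up and run the scheduler. Only the `MPMC` and `MPMCQueue` member names and the requirement that the task payload expose `void operator()()` come from the code above; the `Task` itself, the groupshared `SharedAccessor` (simplified to `uint32_t` elements), the scratch size, the push-constant layout and the seeding of initial work are all hypothetical. A real task would also call `push` on the scheduler instance to spawn children, which is omitted here to keep the sketch self-contained.

#include "schedulers/mpmc.hlsl"

static const uint32_t WorkGroupSize = 256;

// hypothetical payload, the scheduler requires a `void operator()()`
struct Task
{
    void operator()()
    {
        // ... do the work for this payload ...
    }

    uint32_t data;
};

// hypothetical push constants for the scheduler's global spill queue
struct PushConstants
{
    uint64_t queueCountersAddress;
    uint64_t queueStorageAddress;
    uint32_t queueCapacityLog2;
};
[[vk::push_constant]] PushConstants pc;

// hypothetical accessor over a groupshared array sized for the ballot/scan scratch,
// the pop-count slot and the workgroup stack (a real one also needs to store whole `Task`s)
groupshared uint32_t sScratch[2048];
struct SharedAccessor
{
    void set(const uint32_t ix, const uint32_t val) { sScratch[ix] = val; }
    void get(const uint32_t ix, out uint32_t val) { val = sScratch[ix]; }
};

[numthreads(WorkGroupSize,1,1)]
void main()
{
    nbl::hlsl::schedulers::MPMC<Task,WorkGroupSize,SharedAccessor,nbl::hlsl::MPMCQueue<Task> > scheduler;
    scheduler.gQueue.counters = nbl::hlsl::bda::__ptr<uint32_t>::create(pc.queueCountersAddress);
    scheduler.gQueue.pStorage = pc.queueStorageAddress;
    scheduler.gQueue.capacityLog2 = uint16_t(pc.queueCapacityLog2);
    // how eagerly to raid the shared stack and the global queue when invocations go idle
    scheduler.sharedAcceptableIdleCount = 0;
    scheduler.globalAcceptableIdleCount = 0;
    // hypothetical seeding: every invocation starts with one root task
    scheduler.next.data = nbl::hlsl::glsl::gl_LocalInvocationIndex();
    scheduler.nextValid = true;
    // run the scheduler's work loop
    scheduler();
}
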