Add warp specialization as a circular buffering type (#3511)
This PR adds warp specialization as a new type of circular buffering.

Today, we already support pipelined circular buffering, and we can
optionally choose between block sync and mbarrier for handling WAR
hazards. If we choose mbarrier for WAR hazards, we generate a kernel
like the one below:

```python
# Mark buffer[i] as empty and ready to be loaded
for i in range(prefetch + 1):
  arrive(war_mbarrier[i])
# Prologue: thanks to the previous arrives, the WAR waits below complete immediately and the loads just go through
for i in range(prefetch):
  wait war_mbarrier[i]
  arrive-expect-tx raw_mbarrier[i]
  load data[i] into buffer[i]
# Main loop:
for i in range(data.size - prefetch):
  if elect-sync:
    wait war_mbarrier[(i + prefetch) % stage]
    arrive-expect-tx raw_mbarrier[(i + prefetch) % stage]
    load data[i + prefetch] to buffer[(i + prefetch) % stage]
  wait raw_mbarrier[i % stage]
  mma on buffer[i % stage] for data[i]
  wait until there are at most stage - prefetch - 1 pending mma
  arrive war_mbarrier[(i + prefetch + 1) % stage]
# Epilogue
for i in range(data.size - prefetch, data.size):
  wait raw_mbarrier[i % stage]
  mma on buffer[i % stage] for data[i]
wait until there are at most 0 pending mma
write result back to gmem
```

The kernel above has the following problems:
1. The MMA loop is not clean. There is one thread doing the extra work of
loading, while the other threads in the warp group just wait for this one
thread to finish. (Note that mma is a warp-group collective, so all
threads in the warp group must arrive at that instruction for it to
start.) Ideally, we should have a for loop containing only mma and
nothing else; the extra instructions can increase latency.
2. There is a false dependency between the loading of `data[i +
prefetch]` and the computing of `data[i]`. These two operations do not
touch the same data, so in theory they should not depend on each other,
and whichever gets its mbarrier cleared first should go first. However,
because the code executes from top to bottom, the mma has to wait until
the load is issued. This further increases latency.

With the above problems observed, it is natural to ask: why not use
different warps for load and compute? The load code and the compute code
in the main loop are completely independent, and both the RAW and WAR
hazards are handled by mbarriers, which live in shared memory and are
accessible across the entire CTA. So all the preconditions for warp
specialization are in place; we just need to put different IR nodes into
different places.

This PR adds warp specialization. The generated code is similar to the
pipelined code that uses mbarrier for WAR, but actually simpler. It
looks like the code below (assuming warp specialization on TIDy):

```python
if threadIdx.y == blockDim.y - 1:
  # If we use warp specialization on TIDy, then the blockDim.y of the
  # kernel will be (whatever_value_inferred_from_schedule + 1), and the
  # last threadIdx.y will be used as load warp
  for i in range(data.size):
    wait war_mbarrier[i % stage]
    load data[i] to buffer[i % stage]
else:
  # Every threadIdx.y other than the last will be used for compute
  for i in range(prefetch + 1):
    arrive war_mbarrier[i % stage]
  for i in range(data.size):
    wait raw_mbarrier[i % stage]
    compute buffer[i % stage]
    wait until there are at most stage - prefetch - 1 pending mma
    arrive war_mbarrier[(i + prefetch + 1) % stage]
```
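
To make the mbarrier handshake above concrete, here is a small standalone
sketch (not part of this PR; plain host C++ rather than generated kernel code)
that prints, for the assumed values stage = 6 and prefetch = 3, which RAW
barrier each compute iteration waits on and which WAR barrier it releases, and
notes why the load warp can run at most prefetch + 1 = 4 iterations ahead of
the compute warps.

```cpp
#include <cstdio>

int main() {
  constexpr int stage = 6;     // number of circular buffer slots (assumed)
  constexpr int prefetch = 3;  // prefetch distance (assumed)
  constexpr int data_size = 12;

  // Compute-warp side of the handshake, spelled out per iteration.
  for (int i = 0; i < data_size; ++i) {
    std::printf(
        "compute i=%2d: wait raw[%d], use buffer[%d], arrive war[%d]\n",
        i, i % stage, i % stage, (i + prefetch + 1) % stage);
  }
  // The load warp for iteration j waits on war[j % stage] before refilling
  // buffer[j % stage]. Since compute iteration i is the one that arrives on
  // war[(i + prefetch + 1) % stage], the load of data[j] can only start once
  // compute iteration j - prefetch - 1 has released its buffer, i.e. the load
  // warp runs at most prefetch + 1 iterations ahead of the compute warps.
  return 0;
}
```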

This new way of doing circular buffering is intended to be
computation-agnostic: it should work on whatever kernel we are
scheduling, not just matmuls. But note that today there are some strong
limitations that make it less applicable:

1. The computation cannot have a hardcoded `blockDim` in it, so block
reduction will not work. I believe this will be easy to fix, but it is
beyond the scope of this PR.
2. Because the warp-specialized parallel type is no longer exact, thread
predicates will be generated for it. Predicate elimination is not yet
smart enough to know that the code is already inside the compute-warp
branch and does not need to be predicated again. This limitation also
means the computation cannot be a tensor core operation (`MmaOp`), so
this PR does not actually work with matmul yet.

Aside from the above limitations, I believe this new circular buffer
type is pretty generic, and in the future we should be able to try it
with TMA during perf tuning.
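
As a usage sketch (hypothetical; not taken from this PR's tests), and assuming
the existing `TensorView::circularBuffer(number_of_stages, prefetch_distance,
type)` scheduling API simply accepts the new `CircularBufferType` alternative,
opting a TMA-loaded shared-memory tensor into warp specialization might look
like this:

```cpp
// Hypothetical schedule fragment: tv_smem is assumed to be a shared-memory
// TensorView whose loads lower to TMA (cp.async.bulk) and whose
// circular-buffered loop makes warp specialization on TIDy meaningful.
tv_smem->circularBuffer(
    /*number_of_stages=*/6,
    /*prefetch_distance=*/3,
    /*type=*/WarpSpecialized(ParallelType::TIDy));
```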

---------

Co-authored-by: Ryan Spring <[email protected]>
zasdfgbnm and rdspring1 authored Dec 5, 2024
1 parent 1dda106 commit 64bc560
Showing 10 changed files with 268 additions and 31 deletions.
4 changes: 3 additions & 1 deletion csrc/device_lower/pass/allocation.cpp
@@ -89,7 +89,9 @@ Expr* initializeMbarrier(
// threads in the CTA.
num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
DataType::UInt32,
GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock());
GpuLower::current()
->parallelDimensionMap()
.getNumComputeThreadsEachBlock());
}

// Initialize mbarrier for each circular buffer stage. Use the thread
56 changes: 52 additions & 4 deletions csrc/device_lower/pass/circular_buffer.cpp
@@ -106,6 +106,10 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
SimplifyingIrBuilder::create<Val>(opt.prefetch, DataType::Index));
break;
}
case CircularBufferLoopStage::LoadWarp:
case CircularBufferLoopStage::ComputeWarp: {
break;
}
default: {
NVF_THROW("Unsupported loop mode, got: ", loop_type_);
}
@@ -1246,11 +1250,22 @@ class CircularBufferInserter : private kir::ExprMutator {
return;
}

auto hasCpAsyncBulk = std::any_of(
auto has_cp_async_bulk = std::any_of(
it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);

if (hasCpAsyncBulk) {
insertTma(loop, it->second);
bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(loop->iter_domain())
.type);
if (use_warp_specialization) {
NVF_ERROR(
std::all_of(
it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
"In order to use warp specialization, all buffers must be loaded by TMA");
insertTmaWarpSpecialized(loop, it->second);
} else if (has_cp_async_bulk) {
insertTmaPipelined(loop, it->second);
} else {
insert(loop, it->second);
}
@@ -1315,7 +1330,7 @@ class CircularBufferInserter : private kir::ExprMutator {
.usesMBarrierForWAR();
}

void insertTma(
void insertTmaPipelined(
ForLoop* circular_buffer_loop,
const std::vector<Expr*>& loads) {
// Arrive on the WAR mbarriers to let the prefetching start.
@@ -1363,6 +1378,39 @@ class CircularBufferInserter : private kir::ExprMutator {
registerInsertAfter(circular_buffer_loop, epilogue_loop);
}

void insertTmaWarpSpecialized(
ForLoop* circular_buffer_loop,
const std::vector<Expr*>& loads) {
const auto& opt =
GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
circular_buffer_loop->iter_domain());
ParallelType warp_specialize_on = std::get<WarpSpecialized>(opt.type).on;

kir::IfThenElse* warp_dispatch_ite = IrBuilder::create<kir::IfThenElse>(
IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
NamedScalar::getParallelIndex(warp_specialize_on),
IrBuilder::subExpr(
GpuLower::current()->parallelDimensionMap().get(
warp_specialize_on),
circular_buffer_loop->fusion()->oneVal()))));

// Load loop:
ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
warp_dispatch_ite->thenBody().push_back(load_loop);

// Prefetch:
auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
warp_dispatch_ite->elseBody().push_back(prefetch_loop);

// Compute loop:
ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
circular_buffer_loop, loads, CircularBufferLoopStage::ComputeWarp);
warp_dispatch_ite->elseBody().push_back(compute_loop);

registerReplace(circular_buffer_loop, warp_dispatch_ite);
}

void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
NVF_ERROR(
!usesMBarrierForWAR(circular_buffer_loop),
12 changes: 12 additions & 0 deletions csrc/device_lower/pass/insert_syncs.cpp
@@ -468,6 +468,18 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
last_writes_.pop_front();
// Found that a sync is needed

if (!sync_bitmap.hasBID() &&
std::all_of(
expr->inputs().begin(), expr->inputs().end(), [](Val* val) {
return !val->isA<TensorView>() ||
val->as<TensorView>()->getMemoryType() !=
MemoryType::Shared ||
ir_utils::isCpAsyncBulkLoad(val->definition());
})) {
// RAW of TMA is handled separately, so skip it here.
return;
}

// TODO: Explicitly test the 3 cases below
Expr* sync_expr = nullptr;
kir::Allocate* maybe_alloc = nullptr;
36 changes: 30 additions & 6 deletions csrc/device_lower/pass/predicate.cpp
@@ -246,17 +246,41 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
IrBuilder::create<UnaryOp>(
UnaryOpType::ElectSync, elect_sync_val, full_mask_val);

auto load_warp_loop_it =
std::find_if(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) {
return fl->circularBufferLoopStage() ==
CircularBufferLoopStage::LoadWarp;
});
ParallelType load_warp_on = ParallelType::Serial;
if (load_warp_loop_it != for_loops_.end()) {
load_warp_on = std::get<WarpSpecialized>(
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor(
(*load_warp_loop_it)->iter_domain())
.type)
.on;
}

// If we are in a load warp, then the warp-dispatching IfThenElse
// already selects on `load_warp_on`, so we should not generate
// predicates for it here.
const auto& pdim_map = GpuLower::current()->parallelDimensionMap();
Val* first_warp = IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
Val* conditional = load_warp_on == ParallelType::TIDx
? pred->fusion()->trueVal()
: SimplifyingIrBuilder::logicalAndExpr(
elect_sync_val,
IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx),
warp_size));
for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (pdim_map.has(pt)) {
first_warp = SimplifyingIrBuilder::logicalAndExpr(
first_warp,
if (pdim_map.has(pt) && load_warp_on != pt) {
conditional = SimplifyingIrBuilder::logicalAndExpr(
conditional,
IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
}
}
return SimplifyingIrBuilder::logicalAndExpr(first_warp, elect_sync_val);
return conditional;
}
default:
break;
90 changes: 81 additions & 9 deletions csrc/ir/interface_nodes.h
@@ -113,7 +113,7 @@ class TVDomainGuard;

//
// /load 0;\ \.
// / load 1; [prefetch = 3] | [prologue]
// / load 1; [prefetch = 3] | [prefetching]
// [stage] load 2;/ /'
// [ = 6 ] load 3; wait load 0; compute 0; \.
// \ load 4; wait load 1; compute 1; |
@@ -123,7 +123,7 @@ class TVDomainGuard;
// load 2; wait load 5; compute 5; wait compute 3; |
// load 3; wait load 0; compute 0; wait compute 4; |
// load 4; wait load 1; compute 1; wait compute 5; | [main]
// load 5; wait load 2; compute 2; wait compute 0; | [loop]
// load 5; wait load 2; compute 2; wait compute 0; |
// .................................................. |
// .................................................. |
// .................................................. |
@@ -132,7 +132,7 @@ class TVDomainGuard;
// load ; wait load ; compute ; wait compute ; |
// load ; wait load ; compute ; /'
// /wait load ; compute ; \.
// [same number as prefetch] wait load ; compute ; | [epilogue]
// [same number as prefetch] wait load ; compute ; | [draining]
// \wait load ; compute ; wait all computes; /'

// clang-format on
@@ -142,19 +142,37 @@ class TVDomainGuard;
// load pipeline depth = prefetch + 1
// compute pipeline depth = stage - prefetch
//
// The above timeline can be implemented as the following loop structure:
// There are two ways to implement the above timeline: pipelined, and
// warp-specialization.
//
// In the pipelined way, the prefetching stage is implemented as a prologue
// loop, and main stage is implemented as a main loop, and the draining stage is
// implemented as an epilogue loop. That is, we will have the following loop
// structure:
//
// Prologue loop:
// for i in range(prefetch):
// load data[i] to buffer[i]
//
// Main loop:
// Main loop (using syncthreads to avoid WAR hazard):
// for i in range(data.size - prefetch):
// load data[i + prefetch] to buffer[(i + prefetch) % stage]
// wait buffer[i % stage] to be ready
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// __syncthreads();
//
// Main loop (using mbarrier to avoid WAR hazard):
// for i in range(data.size - prefetch):
// wait buffer[(i + prefetch) % stage] to be empty
// load data[i + prefetch] to buffer[(i + prefetch) % stage]
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// signal that buffer (i + prefetch + 1) % stage is empty and ready to be
// loaded again
//
// Epilogue loop:
// for i in range(data.size - prefetch, data.size):
@@ -166,6 +184,30 @@ class TVDomainGuard;
// stage - prefetch - 1 iterations and last iteration of the main loop is
// redundant. We can remove them to further optimize the performance, but
// we decide to keep them for simplicity.
//
// In the warp-specialized approach, we will use different warp/warp-group
// for loading and computing. We will generate code like below (assuming warp
// specialized on TIDy):
//
// if (threadIdx.y == blockDim.y - 1) {
// // If we use warp specialization on TIDy, then the blockDim.y of the
// // kernel will be (whatever_value_inferred_from_schedule + 1), and the
// // last threadIdx.y will be used as load warp
// for i in range(data.size):
// wait buffer[i % stage] to be empty
// load data[i] to buffer[i % stage]
// } else {
// // Every threadIdx.y other than the last will be used for compute
// for i in range(prefetch + 1):
// signal that buffer i % stage is empty and ready to load
// for i in range(data.size):
// wait buffer[i % stage] to be loaded
// compute buffer[i % stage]
// wait until the first compute in the queue is done
// (i.e. stage - prefetch - 1 in flight computes remaining)
// signal that buffer (i + prefetch + 1) % stage is empty and ready to be
// loaded again
// }

struct Pipelined {
bool uses_mbarrier_for_war = false;
@@ -184,7 +226,36 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
return os << "Pipelined";
}

using CircularBufferType = std::variant<Pipelined>;
struct WarpSpecialized {
ParallelType on;
explicit WarpSpecialized(ParallelType on) : on(on) {}
WarpSpecialized() = default;
bool operator==(const WarpSpecialized& other) const {
return on == other.on;
}
};

inline std::ostream& operator<<(
std::ostream& os,
const WarpSpecialized& warp_specialized) {
std::string parallel_type_str = "";
switch (warp_specialized.on) {
case ParallelType::TIDx:
parallel_type_str = "TIDx";
break;
case ParallelType::TIDy:
parallel_type_str = "TIDy";
break;
case ParallelType::TIDz:
parallel_type_str = "TIDz";
break;
default:
NVF_THROW("Invalid parallel type");
}
return os << "WarpSpecializedOn" << parallel_type_str;
}

using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;

inline std::ostream& operator<<(
std::ostream& os,
@@ -207,8 +278,9 @@ struct CircularBufferOptions {
}

bool usesMBarrierForWAR() const {
return std::holds_alternative<Pipelined>(type) &&
std::get<Pipelined>(type).uses_mbarrier_for_war;
return (std::holds_alternative<Pipelined>(type) &&
std::get<Pipelined>(type).uses_mbarrier_for_war) ||
std::holds_alternative<WarpSpecialized>(type);
return false;
}

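For readers skimming the header change above, here is a self-contained sketch
(simplified stand-in types, outside the nvFuser codebase) of the new
`CircularBufferType` variant semantics: warp specialization always implies
mbarrier-based WAR handling, while the pipelined type only does so when
explicitly requested.

```cpp
#include <iostream>
#include <variant>

// Simplified stand-ins for the real nvFuser types.
enum class ParallelType { TIDx, TIDy, TIDz };

struct Pipelined {
  bool uses_mbarrier_for_war = false;
};

struct WarpSpecialized {
  ParallelType on;
};

using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;

// Mirrors the intent of CircularBufferOptions::usesMBarrierForWAR().
bool usesMBarrierForWAR(const CircularBufferType& type) {
  return (std::holds_alternative<Pipelined>(type) &&
          std::get<Pipelined>(type).uses_mbarrier_for_war) ||
      std::holds_alternative<WarpSpecialized>(type);
}

int main() {
  std::cout << usesMBarrierForWAR(Pipelined{}) << "\n";                            // 0
  std::cout << usesMBarrierForWAR(Pipelined{true}) << "\n";                        // 1
  std::cout << usesMBarrierForWAR(WarpSpecialized{ParallelType::TIDy}) << "\n";    // 1
  return 0;
}
```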
28 changes: 27 additions & 1 deletion csrc/parallel_dimension_map.cpp
@@ -38,9 +38,17 @@ struct hash<PAndID> {
namespace nvfuser {

void ParallelDimensionMap::build(Fusion* fusion) {
VectorOfUniqueEntries<ParallelType> warp_specialized_types;
VectorOfUniqueEntries<PAndID> all_concrete_ids;
auto all_vals = fusion->usedMathVals();
for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
if (tv->isCircularBuffered() &&
std::holds_alternative<WarpSpecialized>(
tv->circularBufferOptions().type)) {
const auto& warp_specialized =
std::get<WarpSpecialized>(tv->circularBufferOptions().type);
warp_specialized_types.pushBack(warp_specialized.on);
}
for (auto id : tv->domain()->allIDs()) {
auto ptype = id->getParallelType();
if (!isParallelTypeThread(ptype)) {
@@ -83,6 +91,10 @@ void ParallelDimensionMap::build(Fusion* fusion) {
}

adjustMappingsForWarpPadding();

for (auto pt : warp_specialized_types) {
setWarpSpecializeOn(pt);
}
}

void ParallelDimensionMap::adjustMappingsForWarpPadding() {
@@ -137,6 +149,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
exact_types_.erase(ParallelType::TIDx);
}

void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) {
auto dim_it = dim_map_.find(pt);
if (dim_it == dim_map_.end()) {
dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
} else {
dim_map_[pt] = SimplifyingIrBuilder::addExpr(dim_it->second, 1);
}
exact_types_.erase(pt);
warp_specialized_types_.insert(pt);
}

Val* ParallelDimensionMap::getRaw(ParallelType pt) const {
NVF_ERROR(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
auto it = dim_map_.find(pt);
@@ -159,13 +182,16 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
return exact_types_.find(pt) != exact_types_.end();
}

Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
Val* num_threads = FusionGuard::getCurFusion()->oneVal();
for (auto pt : kParallelTypeTIDs) {
auto dim = getRaw(pt);
if (dim == nullptr) {
continue;
}
if (warp_specialized_types_.find(pt) != warp_specialized_types_.end()) {
dim = SimplifyingIrBuilder::addExpr(dim, -1);
}
num_threads = SimplifyingIrBuilder::mulExpr(num_threads, dim);
}
return num_threads;
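To illustrate the bookkeeping in `setWarpSpecializeOn` and
`getNumComputeThreadsEachBlock` above, a small standalone sketch with made-up
dimensions (not nvFuser code): if the schedule infers TIDx = 128 and TIDy = 2
and we warp-specialize on TIDy, the kernel launches with blockDim = (128, 3, 1);
the extra threadIdx.y slice is the load warp group, and mbarriers whose
expected arrival count used to be the CTA thread count should now count only
the 128 * 2 = 256 compute threads.

```cpp
#include <cstdio>

int main() {
  // Thread dimensions assumed to be inferred from the schedule (hypothetical).
  int tidx = 128, tidy = 2, tidz = 1;

  // setWarpSpecializeOn(TIDy): the warp-specialized dimension gets one extra
  // slice for the load warp group and is no longer marked exact.
  int launch_tidy = tidy + 1;

  // getNumComputeThreadsEachBlock(): subtract the load warp group, since only
  // compute threads arrive on the mbarriers used by the compute loop.
  int compute_threads = tidx * (launch_tidy - 1) * tidz;

  std::printf("launch blockDim = (%d, %d, %d)\n", tidx, launch_tidy, tidz);
  std::printf("compute threads per block = %d\n", compute_threads);
  std::printf("load warp group: threadIdx.y == %d\n", launch_tidy - 1);
  return 0;
}
```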