Only the TMA thread arrives (#3294)
Previously:
```C++
if (elect-sync) {
  arriveExpectTx
  TMA
} else {
  arrive
}
```
Now:
```C++
if (elect-sync) {
  arriveExpectTx
  TMA
}
```

I am very surprised that this not only fixes all the latency introduced by the
elect-sync fix #3295, but is even
better! In general, though, we should synchronize as little as possible and avoid
unnecessary waits, so I think this PR makes sense.
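
For reference, here is a minimal device-side sketch of the pattern the kernel now relies on (my own illustration assuming sm_90 inline PTX; the helper names are hypothetical, and nvFuser's generated runtime calls differ): one thread per warp is elected to issue the arrive-expect-tx plus the TMA load, and no other thread touches the mbarrier until the wait.
```C++
#include <cstdint>

// Hypothetical helpers for illustration only (not nvFuser's runtime).
// Returns 1 on exactly one active lane of the warp, 0 elsewhere (sm_90).
__device__ uint32_t elect_one_sync() {
  uint32_t is_elected;
  asm volatile(
      "{\n"
      "  .reg .pred p;\n"
      "  elect.sync _|p, 0xffffffff;\n"
      "  selp.u32 %0, 1, 0, p;\n"
      "}"
      : "=r"(is_elected));
  return is_elected;
}

// Single arrival that also registers the expected TMA transaction bytes.
__device__ void arrive_expect_tx(uint64_t* smem_bar, uint32_t tx_bytes) {
  uint32_t bar = static_cast<uint32_t>(__cvta_generic_to_shared(smem_bar));
  asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;" ::
                   "r"(bar), "r"(tx_bytes));
}
```
With the mbarrier's arrival count initialized to match only these elected arrivals (see the allocation.cpp diff below), the else-branch arrive is dead weight rather than a correctness requirement.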

Perf:
```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     39.0           172735          1  172735.0  172735.0    172735    172735          0.0  <unnamed>::nvfuser_none_f0_c0_r0_g0(<unnamed>::Tensor<<unnamed>::__half, (int)3, (int)3>, <unnamed>…
     20.0            88768          1   88768.0   88768.0     88768     88768          0.0  nvjet_hsh_256x128_64x4_1x2_h_bz_coopA_NTT
```

Perf nvFuser/cuBLAS: `51.4%`.
zasdfgbnm authored Oct 29, 2024 · 1 parent 7a3b1a4 · commit e33316d
Showing 2 changed files with 6 additions and 47 deletions.
csrc/device_lower/pass/allocation.cpp (2 additions, 21 deletions)

```diff
@@ -67,25 +67,6 @@ Expr* initializeMbarrier(
   kir::TensorIndex* stage_mbarrier =
       IrBuilder::create<kir::TensorIndex>(all_mbarriers, loop->index());
 
-  // Get all threads in CTA
-  Val* bdimx =
-      GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDx);
-  Val* bdimy =
-      GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDy);
-  Val* bdimz =
-      GpuLower::current()->parallelDimensionMap().get(ParallelType::TIDz);
-  Val* all_threads_in_cta = SimplifyingIrBuilder::mulExpr(
-      bdimx, SimplifyingIrBuilder::mulExpr(bdimy, bdimz));
-  if (all_threads_in_cta != nullptr) {
-    all_threads_in_cta = SimplifyingIrBuilder::maybeCastExpr(
-        DataType::UInt32, all_threads_in_cta);
-  } else {
-    // If all_threads_in_cta is nullptr, then this kernel is not parallelized
-    // on any of the thread dimensions.
-    all_threads_in_cta =
-        GpuLower::current()->kernel()->oneVal(DataType::UInt32);
-  }
-
   auto circular_buffered_tvs =
       GpuLower::current()->circularBufferInfo().getCircularBufferTvs(
           circular_buffer_loop);
@@ -95,8 +76,8 @@ Expr* initializeMbarrier(
       [](const TensorView* tv) {
         return ir_utils::isCpAsyncBulkLoad(tv->definition());
       });
-  Val* n = IrBuilder::create<Val>(num_of_tvs_loaded_by_tma, DataType::UInt32);
-  Val* num_of_arrives = SimplifyingIrBuilder::mulExpr(n, all_threads_in_cta);
+  Val* num_of_arrives =
+      IrBuilder::create<Val>(num_of_tvs_loaded_by_tma, DataType::UInt32);
 
   // Initialize mbarrier for each circular buffer stage. Use the thread
   // count from the MBarrierInit created in the allocation pass. The wait
```
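The semantics this hunk changes can be sketched with a hypothetical raw-PTX wrapper (again my illustration, not nvFuser's actual `mbarrier::init` runtime function): the arrival count baked into `mbarrier.init` now counts one arrival per TMA-loaded tensor instead of one per thread per tensor.
```C++
#include <cstdint>

// Illustration only: initialize an mbarrier in shared memory (sm_90).
__device__ void mbarrier_init(uint64_t* smem_bar, uint32_t arrive_count) {
  uint32_t bar = static_cast<uint32_t>(__cvta_generic_to_shared(smem_bar));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(bar),
               "r"(arrive_count));
}

// Before this commit: arrive_count = num_tma_loads * bdimx * bdimy * bdimz
// After this commit:  arrive_count = num_tma_loads
// (only the elected thread arrives, once per TMA load, via arriveExpectTx)
```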
csrc/device_lower/pass/circular_buffer.cpp (4 additions, 26 deletions)

```diff
@@ -261,27 +261,20 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
 // Detailed Pseudo-Code:
 // Pre-Prologue loop:
 //
-// - number_of_arrival_threads is the number of threads to call
-//   mbarrier::arrive or mbarrier::arriveExpectTx and to wait at
-//   mbarrier:wait.
-//
 //   __shared__ __mbarrier_t barriers[num_stages];
 //   if (warp_id == 0 && electSync()()) {
 //     for (int64_t loop_index : irange(stages)) {
-//       int64_t number_of_arrive_threads = blockDim.x * blockDim.y * blockDim.z;
-//       mbarrier_init(mbarrier[loop_index], number_of_arrival_threads);
+//       mbarrier_init(mbarrier[loop_index], number_of_tma_load_exprs);
 //     }
 //   }
 //
 // Prologue loop:
 //   for (int64_t loop_index : irange(prefetch_distance)) {
-//     if (warp_id == 0 && electSync()()) {
+//     if (warp_id == 0 && electSync()) {
 //       mbarrier::arriveExpectTx(mbarrier[loop_index], expected_bytes);
 //       for (...) {
 //         cpAsyncBulk(mbarriers[loop_index], ...);
 //       }
-//     } else {
-//       mbarrier::arrive(mbarrier[loop_index]);
 //     }
 //   }
 //
@@ -294,8 +287,6 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
 //     for (...) {
 //       cpAsyncBulk(mbarrier[load_stage], ...);
 //     }
-//   } else {
-//     mbarrier::arrive(mbarrier[load_stage]);
 //   }
 //   mbarrier::waitParity((loop_index / stage_depth) % 2);
 //
@@ -363,8 +354,8 @@ class CloneTmaCircularBufferLoopAndInsertSync
   // generate the nested for-loops for the serial IterDomains, but do not add
   // them to the cloned circular buffer loop immediately. Once the cloned
   // circular buffer loop is the only loop in the stack, add the arriveExpectTx
-  // and arrive expressions, then the nested for-loop structure calling the TMA
-  // load operations, and finally the mbarrier_wait.
+  // expressions, then the nested for-loop structure calling the TMA load
+  // operations, and finally the mbarrier_wait.
   void processForLoop(ForLoop* cloned_loop) final {
     // Skip if there is not an active for-loop structure
     if (for_loop_stack_.empty()) {
@@ -412,8 +403,6 @@ class CloneTmaCircularBufferLoopAndInsertSync
     //   for (...) {
     //     cpAsyncBulk;
     //   }
-    // } else {
-    //   arrive;
     // }
     NVF_ERROR(for_loop_stack_.front() == cloned_top_level_loop_);
     addTmaLoadBlock(cloned_loop);
@@ -602,8 +591,6 @@ class CloneTmaCircularBufferLoopAndInsertSync
   //     for (...) {
   //       cpAsyncBulk(mbarriers[loop_index], ...);
   //     }
-  //   } else {
-  //     mbarrier::arrive(mbarrier[loop_index]);
   //   }
   // }
   void handlePrologueLoop(Expr* expr) {
@@ -656,8 +643,6 @@ class CloneTmaCircularBufferLoopAndInsertSync
   //     for (...) {
   //       cpAsyncBulk(mbarrier[load_stage], ...);
   //     }
-  //   } else {
-  //     mbarrier::arrive(mbarrier[load_stage]);
   //   }
   //   mbarrier::wait((loop_index / stage_depth) % 2);
   //
@@ -725,8 +710,6 @@ class CloneTmaCircularBufferLoopAndInsertSync
   //     for (...) {
   //       cpAsyncBulk(mbarrier[next_stage], ...);
   //     }
-  //   } else {
-  //     mbarrier::arrive(mbarrier[next_stage]);
   //   }
   //
   // The expr input argument can be a single cpAsyncBulk expression or a nested
@@ -745,11 +728,6 @@ class CloneTmaCircularBufferLoopAndInsertSync
   // launches the TMA load.
   if_expr->thenBody().push_back(mbarrier_arrive_tx_);
   if_expr->thenBody().push_back(expr);
-
-  // The other threads issue arriveExpectTx without any expected transactions.
-  kir::MBarrierArrive* thread_arrive = IrBuilder::create<kir::MBarrierArrive>(
-      /*state=*/nullptr, mbarrier_arrive_tx_->mbarrier());
-  if_expr->elseBody().push_back(thread_arrive);
   for_loop_stack_.back()->body().push_back(if_expr);
 
   mbarrier_arrive_tx_ = nullptr;
```
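For completeness, the consumer side that this diff leaves in place (the `mbarrier::waitParity` in the pseudo-code above) can be sketched the same way, with a hypothetical sm_90 wrapper: every thread, elected or not, blocks on the phase parity, which completes only after the elected thread's arrival and the full TMA byte count have landed, so the deleted else-branch arrives were never needed for the wait to succeed.
```C++
#include <cstdint>

// Illustration only: spin on an mbarrier phase's parity (sm_90).
__device__ void mbarrier_wait_parity(uint64_t* smem_bar, uint32_t parity) {
  uint32_t bar = static_cast<uint32_t>(__cvta_generic_to_shared(smem_bar));
  uint32_t done = 0;
  // try_wait returns true once the phase with the given parity completes.
  while (!done) {
    asm volatile(
        "{\n"
        "  .reg .pred p;\n"
        "  mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n"
        "  selp.u32 %0, 1, 0, p;\n"
        "}"
        : "=r"(done)
        : "r"(bar), "r"(parity));
  }
}
```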
