# Indexing changes for TMA Circular Buffering (#2825)
## Summary ##
This PR contains the changes to the index lowering pass from #2773. It is stacked on #2824.
Tracking Branch: #2773

## Details ##
- In the circular buffering pass, we manually index the mbarriers and
tokens using the index of the prologue, main, and epilogue loops; a small
runnable sketch of this stage arithmetic follows this list.
```cpp
for (int index : c10::irange(fl->extent())) {
  int stage = index % number_of_pipeline_stages;
  mbarrier_t current_stage_mbarrier = mbarriers[stage];  // represented with kir::TensorIndex

  int next_stage = (index + number_of_pipeline_stages - 1) % number_of_pipeline_stages;
  mbarrier_t next_stage_mbarrier = mbarriers[next_stage];  // represented with kir::TensorIndex
}
```
- The handle functions for `kir::MBarrierInit`, `kir::MBarrierInvalidate`,
`kir::MBarrierArriveExpectTx`, and `kir::MBarrierWait` are modified to handle a
`kir::TensorIndex` mbarrier; a generic sketch of this dispatch idiom also
follows this list.
- `u32IndexScalarSmemTv` is overloaded to get the shared memory pointer
address for a `kir::TensorIndex`.
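
For illustration, here is a minimal, self-contained sketch of the stage arithmetic shown in the snippet above, assuming a hypothetical 3-stage pipeline and a loop extent of 6 (both constants are made up for the example; in the pass they come from the fusion):

```cpp
#include <iostream>

int main() {
  // Illustrative constants only; in the pass, the loop extent comes from
  // fl->extent() and the stage count from the circular buffering options.
  constexpr int number_of_pipeline_stages = 3;
  constexpr int extent = 6;

  for (int index = 0; index < extent; ++index) {
    // Mbarrier slot used by this iteration.
    int stage = index % number_of_pipeline_stages;
    // Mbarrier slot referred to as "next_stage" in the description above,
    // i.e., one position behind the current one modulo the stage count.
    int next_stage =
        (index + number_of_pipeline_stages - 1) % number_of_pipeline_stages;
    std::cout << "index " << index << " -> stage " << stage << ", next_stage "
              << next_stage << "\n";
  }
  return 0;
}
```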
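
The handlers in the diff below resolve the mbarrier to a shared-memory address by checking whether it is a `TensorView` or a `kir::TensorIndex`. The sketch below shows that dispatch idiom in isolation, using hypothetical stand-in classes and placeholder addresses rather than the real nvFuser IR types or `lower_utils::u32IndexScalarSmemTv`:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for IR nodes; the real classes live in nvFuser.
struct Node {
  virtual ~Node() = default;
  template <typename T>
  bool isA() const {
    return dynamic_cast<const T*>(this) != nullptr;
  }
  template <typename T>
  const T* as() const {
    return dynamic_cast<const T*>(this);
  }
};
struct TensorView : Node {};   // un-indexed smem tensor (single mbarrier)
struct TensorIndex : Node {};  // smem tensor plus a per-stage index

// Placeholder address helpers standing in for u32IndexScalarSmemTv overloads.
uint32_t smemAddressOf(const TensorView*) {
  return 0x100;
}
uint32_t smemAddressOf(const TensorIndex*) {
  return 0x140;
}

// Stand-in for the handler logic: resolve either form of the mbarrier
// to a shared-memory address based on its runtime type.
uint32_t resolveMBarrierAddress(const Node* mbarrier) {
  if (mbarrier->isA<TensorView>()) {
    return smemAddressOf(mbarrier->as<TensorView>());
  } else if (mbarrier->isA<TensorIndex>()) {
    return smemAddressOf(mbarrier->as<TensorIndex>());
  }
  assert(false && "Unexpected mbarrier value");
  return 0;
}

int main() {
  TensorIndex per_stage_barrier;
  std::cout << std::hex << resolveMBarrierAddress(&per_stage_barrier) << "\n";
  return 0;
}
```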
rdspring1 authored Sep 5, 2024
1 parent 981ef1a commit 752c0fe
Showing 4 changed files with 121 additions and 28 deletions.
135 changes: 107 additions & 28 deletions csrc/device_lower/pass/index.cpp
@@ -1404,45 +1404,124 @@ void IndexLowering::handleGroupedGridWelford(
}

void IndexLowering::handle(const kir::MBarrierInit* minit) {
-  auto minit_indexed = IrBuilder::create<kir::MBarrierInit>(
-      lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>()),
-      minit->threadCount());
+  Val* smem_address_ptr = nullptr;
+
+  if (minit->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>());
+  } else if (minit->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minit->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(false, "Unexpected MBarrierInit value.");
+  }
+  kir::MBarrierInit* minit_indexed = IrBuilder::create<kir::MBarrierInit>(
+      smem_address_ptr, minit->threadCount());
  pushBack(minit_indexed);
  GpuLower::current()->propagateExprInfo(minit, minit_indexed);
}

void IndexLowering::handle(const kir::MBarrierInvalidate* minval) {
-  auto minval_indexed = IrBuilder::create<kir::MBarrierInvalidate>(
-      lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>()));
+  Val* smem_address_ptr = nullptr;
+
+  if (minval->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>());
+  } else if (minval->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minval->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(
+        false,
+        "Unexpected MBarrierInvalidate barrier value: ",
+        minval->mbarrier()->toString());
+  }
+  kir::MBarrierInvalidate* minval_indexed =
+      IrBuilder::create<kir::MBarrierInvalidate>(smem_address_ptr);
  pushBack(minval_indexed);
  GpuLower::current()->propagateExprInfo(minval, minval_indexed);
}

-void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
-  // indexing mbarrier
-  auto mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
-  auto mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
-
-  // gmem indexing and expect_bytes for mbarrier
-  auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
-      ldst, mbarrier_index, for_loops_, rotated_loop_);
-
-  // arrive and expect_tx mbarrier
-  auto state = IrBuilder::create<Val>(DataType::UInt);
-  pushBack(IrBuilder::create<kir::Allocate>(
-      state, MemoryType::Local, ldst->container()->oneVal()));
-  pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
-      state, mbarrier_index, expect_bytes));
-
-  // indexing ldst op
-  auto out = lowerDstIndex(ldst->out(), {}, true);
-  auto new_ldst =
-      IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
-          ->withPredicate(ldst->predicate());
-  pushBack(new_ldst);
-  GpuLower::current()->propagateExprInfo(ldst, back());
-  // wait mbarrier
-  pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
-}
+void IndexLowering::handle(
+    const kir::MBarrierArriveExpectTx* arrive_transaction) {
+  NVF_ERROR(
+      arrive_transaction->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierArriveExpectTx");
+
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      arrive_transaction->mbarrier()->as<kir::TensorIndex>());
+  pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
+      arrive_transaction->state(),
+      smem_address_ptr,
+      arrive_transaction->txCount()));
+}
+
+void IndexLowering::handle(const kir::MBarrierWait* mwait) {
+  NVF_ERROR(
+      mwait->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierWait");
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      mwait->mbarrier()->as<kir::TensorIndex>());
+  pushBack(
+      IrBuilder::create<kir::MBarrierWait>(smem_address_ptr, mwait->state()));
+}
+
+void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
+  // If LoadStoreOp has a smem TV in ldstMBarrierTokenMap, then it is a part
+  // of a circular buffer loop. The kir nodes for arrive_expect_tx and
+  // mbarrier_wait are added by the circular buffer pass. Otherwise, those
+  // nodes are added here.
+  bool is_circular_buffered =
+      (GpuLower::current()->ldstMBarrierIndexMap().count(ldst) != 0);
+
+  if (is_circular_buffered) {
+    kir::TensorIndex* mbarrier =
+        GpuLower::current()->ldstMBarrierIndexMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, _] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    // register new LoadStoreOp with mbarrier
+    GpuLower::current()->ldstMBarrierIndexMap()[new_ldst] = mbarrier;
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+  } else {
+    TensorView* mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // arrive and expect_tx mbarrier
+    Val* state = IrBuilder::create<Val>(DataType::UInt);
+    pushBack(IrBuilder::create<kir::Allocate>(
+        state, MemoryType::Local, ldst->container()->oneVal()));
+    pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
+        state, mbarrier_index, expect_bytes));
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+    // wait mbarrier
+    pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
+  }
+}

void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
@@ -74,6 +74,8 @@ class IndexLowering : private OptOutConstDispatch {
  void handle(const kir::GridSync*) final;
  void handle(const kir::MBarrierInit*) final;
  void handle(const kir::MBarrierInvalidate*) final;
+  void handle(const kir::MBarrierArriveExpectTx*) final;
+  void handle(const kir::MBarrierWait*) final;
  void handle(const kir::AsyncWait*) final;
  void handle(const kir::AsyncCommit*) final;
  void handle(const kir::BlockSerializeWait*) final;
Expand Down
8 changes: 8 additions & 0 deletions csrc/device_lower/utils.cpp
@@ -845,6 +845,14 @@ Val* u32IndexScalarSmemTv(TensorView* smem_tv) {
  return u32addr;
}

+Val* u32IndexScalarSmemTv(kir::TensorIndex* index) {
+  auto ptr_address = IrBuilder::addressExpr(index);
+  auto u32addr = IrBuilder::create<Val>(DataType::SMemAddress);
+  IrBuilder::create<UnaryOp>(
+      UnaryOpType::ToUnsignedSmemAddr, u32addr, ptr_address);
+  return u32addr;
+}

Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
  // See the comment above for getGridCommWorkBufferSize.
  NVF_ERROR(
4 changes: 4 additions & 0 deletions csrc/device_lower/utils.h
@@ -306,6 +306,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

+//! Get the uint32_t index of a TensorIndex. This is usually used for
+//! initializing a pipeline of mbarriers.
+NVF_API Val* u32IndexScalarSmemTv(kir::TensorIndex* index);

//! Get the size of a global sync buffer needed to perform a grid reduction for
//! each axis in bitmap.
Val* getGridSyncBufferSize(const ParallelTypeBitmap& bitmap);
