Add Indexing changes
rdspring1 committed Aug 22, 2024
1 parent 12db3ee commit 2491171
Showing 4 changed files with 118 additions and 28 deletions.
132 changes: 104 additions & 28 deletions csrc/device_lower/pass/index.cpp
@@ -1404,45 +1404,121 @@ void IndexLowering::handleGroupedGridWelford(
}

void IndexLowering::handle(const kir::MBarrierInit* minit) {
-  auto minit_indexed = IrBuilder::create<kir::MBarrierInit>(
-      lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>()),
-      minit->threadCount());
+  Val* smem_address_ptr = nullptr;
+
+  if (minit->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>());
+  } else if (minit->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minit->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(false, "Unexpected MBarrierInit value.");
+  }
+  kir::MBarrierInit* minit_indexed = IrBuilder::create<kir::MBarrierInit>(
+      smem_address_ptr, minit->threadCount());
  pushBack(minit_indexed);
  GpuLower::current()->propagateExprInfo(minit, minit_indexed);
}

void IndexLowering::handle(const kir::MBarrierInvalidate* minval) {
-  auto minval_indexed = IrBuilder::create<kir::MBarrierInvalidate>(
-      lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>()));
+  Val* smem_address_ptr = nullptr;
+
+  if (minval->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>());
+  } else if (minval->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minval->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(false, "Unexpected MBarrierInval value.");
+  }
+  kir::MBarrierInvalidate* minval_indexed =
+      IrBuilder::create<kir::MBarrierInvalidate>(smem_address_ptr);
  pushBack(minval_indexed);
  GpuLower::current()->propagateExprInfo(minval, minval_indexed);
}

-void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
-  // indexing mbarrier
-  auto mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
-  auto mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
-
-  // gmem indexing and expect_bytes for mbarrier
-  auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
-      ldst, mbarrier_index, for_loops_, rotated_loop_);
-
-  // arrive and expect_tx mbarrier
-  auto state = IrBuilder::create<Val>(DataType::UInt);
-  pushBack(IrBuilder::create<kir::Allocate>(
-      state, MemoryType::Local, ldst->container()->oneVal()));
-  pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
-      state, mbarrier_index, expect_bytes));
-
-  // indexing ldst op
-  auto out = lowerDstIndex(ldst->out(), {}, true);
-  auto new_ldst =
-      IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
-          ->withPredicate(ldst->predicate());
-  pushBack(new_ldst);
-  GpuLower::current()->propagateExprInfo(ldst, back());
-  // wait mbarrier
-  pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
-}
+void IndexLowering::handle(
+    const kir::MBarrierArriveExpectTx* arrive_transaction) {
+  NVF_ERROR(
+      arrive_transaction->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierArriveExpectTx");
+
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      arrive_transaction->mbarrier()->as<kir::TensorIndex>());
+  pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
+      arrive_transaction->state(),
+      smem_address_ptr,
+      arrive_transaction->txCount()));
+}
+
+void IndexLowering::handle(const kir::MBarrierWait* mwait) {
+  NVF_ERROR(
+      mwait->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierWait");
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      mwait->mbarrier()->as<kir::TensorIndex>());
+  pushBack(
+      IrBuilder::create<kir::MBarrierWait>(smem_address_ptr, mwait->state()));
+}

+void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
+  // If LoadStoreOp has a smem TV in ldstMBarrierTokenMap, then it is a part
+  // of a circular buffer loop. The kir nodes for arrive_expect_tx and
+  // mbarrier_wait are added by the circular buffer pass. Otherwise, those
+  // nodes are added here.
+  bool is_circular_buffered =
+      (GpuLower::current()->ldstMBarrierIndexMap().count(ldst) != 0);
+
+  if (is_circular_buffered) {
+    kir::TensorIndex* mbarrier =
+        GpuLower::current()->ldstMBarrierIndexMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, _] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    // register new LoadStoreOp with mbarrier
+    GpuLower::current()->ldstMBarrierIndexMap()[new_ldst] = mbarrier;
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+  } else {
+    TensorView* mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // arrive and expect_tx mbarrier
+    Val* state = IrBuilder::create<Val>(DataType::UInt);
+    pushBack(IrBuilder::create<kir::Allocate>(
+        state, MemoryType::Local, ldst->container()->oneVal()));
+    pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
+        state, mbarrier_index, expect_bytes));
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+    // wait mbarrier
+    pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
+  }
+}

void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
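A condensed sketch of handleCpAsyncBulkLoad above, not taken verbatim from the commit: it recombines only calls that appear in the hunk, with the mbarrier indexing and the construction of out, in, state, and new_ldst elided, to contrast the kernel-IR nodes each path pushes back.

// Hedged restatement of the two lowering paths; illustration only.
if (GpuLower::current()->ldstMBarrierIndexMap().count(ldst) != 0) {
  // Circular-buffered: the mbarrier is already a kir::TensorIndex, and the
  // circular buffer pass emits MBarrierArriveExpectTx / MBarrierWait, so only
  // the indexed LoadStoreOp is pushed back here.
  pushBack(new_ldst);
} else {
  // Stand-alone TMA load: allocate a local token, arrive on the mbarrier with
  // the expected byte count, issue the load, then wait on the mbarrier.
  pushBack(IrBuilder::create<kir::Allocate>(
      state, MemoryType::Local, ldst->container()->oneVal()));
  pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
      state, mbarrier_index, expect_bytes));
  pushBack(new_ldst);
  pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
}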
2 changes: 2 additions & 0 deletions csrc/device_lower/pass/index.h
@@ -74,6 +74,8 @@ class IndexLowering : private OptOutConstDispatch {
  void handle(const kir::GridSync*) final;
  void handle(const kir::MBarrierInit*) final;
  void handle(const kir::MBarrierInvalidate*) final;
+  void handle(const kir::MBarrierArriveExpectTx*) final;
+  void handle(const kir::MBarrierWait*) final;
  void handle(const kir::AsyncWait*) final;
  void handle(const kir::AsyncCommit*) final;
  void handle(const kir::BlockSerializeWait*) final;
8 changes: 8 additions & 0 deletions csrc/device_lower/utils.cpp
@@ -838,6 +838,14 @@ Val* u32IndexScalarSmemTv(TensorView* smem_tv) {
  return u32addr;
}

+Val* u32IndexScalarSmemTv(kir::TensorIndex* index) {
+  auto ptr_address = IrBuilder::addressExpr(index);
+  auto u32addr = IrBuilder::create<Val>(DataType::SMemAddress);
+  IrBuilder::create<UnaryOp>(
+      UnaryOpType::ToUnsignedSmemAddr, u32addr, ptr_address);
+  return u32addr;
+}
+
Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
  // See the comment above for getGridCommWorkBufferSize.
  NVF_ERROR(
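For context on where the new overload is called: the mbarrier handlers in csrc/device_lower/pass/index.cpp above use it when the barrier operand is a single indexed element of an mbarrier array (one pipeline stage) rather than a whole TensorView, as in this excerpt from the MBarrierWait handler shown earlier:

// Usage copied from the MBarrierWait handler above; mwait is the
// kir::MBarrierWait expression being lowered.
Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
    mwait->mbarrier()->as<kir::TensorIndex>());
pushBack(
    IrBuilder::create<kir::MBarrierWait>(smem_address_ptr, mwait->state()));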
4 changes: 4 additions & 0 deletions csrc/device_lower/utils.h
@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
//! indexing special items in shared memory, like mbarrier.
NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);

+//! Get the uint32_t index of a TensorIndex. This is usually used for
+//! initializing a pipeline of mbarriers.
+NVF_API Val* u32IndexScalarSmemTv(kir::TensorIndex* index);
+
//! Get the size of a global sync buffer needed to perform a grid reduction for
//! each axis in bitmap.
Val* getGridSyncBufferSize(const ParallelTypeBitmap& bitmap);
