diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index 0e0cd87d316..dad23f7e953 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -1404,45 +1404,121 @@ void IndexLowering::handleGroupedGridWelford(
 }
 
 void IndexLowering::handle(const kir::MBarrierInit* minit) {
-  auto minit_indexed = IrBuilder::create<kir::MBarrierInit>(
-      lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>()),
-      minit->threadCount());
+  Val* smem_address_ptr = nullptr;
+
+  if (minit->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minit->mbarrier()->as<TensorView>());
+  } else if (minit->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minit->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(false, "Unexpected MBarrierInit value.");
+  }
+  kir::MBarrierInit* minit_indexed = IrBuilder::create<kir::MBarrierInit>(
+      smem_address_ptr, minit->threadCount());
   pushBack(minit_indexed);
   GpuLower::current()->propagateExprInfo(minit, minit_indexed);
 }
 
 void IndexLowering::handle(const kir::MBarrierInvalidate* minval) {
-  auto minval_indexed = IrBuilder::create<kir::MBarrierInvalidate>(
-      lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>()));
+  Val* smem_address_ptr = nullptr;
+
+  if (minval->mbarrier()->isA<TensorView>()) {
+    smem_address_ptr =
+        lower_utils::u32IndexScalarSmemTv(minval->mbarrier()->as<TensorView>());
+  } else if (minval->mbarrier()->isA<kir::TensorIndex>()) {
+    smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+        minval->mbarrier()->as<kir::TensorIndex>());
+  } else {
+    NVF_ERROR(false, "Unexpected MBarrierInval value.");
+  }
+  kir::MBarrierInvalidate* minval_indexed =
+      IrBuilder::create<kir::MBarrierInvalidate>(smem_address_ptr);
   pushBack(minval_indexed);
   GpuLower::current()->propagateExprInfo(minval, minval_indexed);
 }
 
-void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
-  // indexing mbarrier
-  auto mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
-  auto mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
-
-  // gmem indexing and expect_bytes for mbarrier
-  auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
-      ldst, mbarrier_index, for_loops_, rotated_loop_);
-
-  // arrive and expect_tx mbarrier
-  auto state = IrBuilder::create<Val>(DataType::UInt);
-  pushBack(IrBuilder::create<kir::Allocate>(
-      state, MemoryType::Local, ldst->container()->oneVal()));
+void IndexLowering::handle(
+    const kir::MBarrierArriveExpectTx* arrive_transaction) {
+  NVF_ERROR(
+      arrive_transaction->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierArriveExpectTx");
+
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      arrive_transaction->mbarrier()->as<kir::TensorIndex>());
   pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
-      state, mbarrier_index, expect_bytes));
+      arrive_transaction->state(),
+      smem_address_ptr,
+      arrive_transaction->txCount()));
+}
 
-  // indexing ldst op
-  auto out = lowerDstIndex(ldst->out(), {}, true);
-  auto new_ldst =
-      IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
-          ->withPredicate(ldst->predicate());
-  pushBack(new_ldst);
-  GpuLower::current()->propagateExprInfo(ldst, back());
-  // wait mbarrier
-  pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
+void IndexLowering::handle(const kir::MBarrierWait* mwait) {
+  NVF_ERROR(
+      mwait->mbarrier()->isA<kir::TensorIndex>(),
+      "Expected kir::TensorIndex in MBarrierWait");
+  Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv(
+      mwait->mbarrier()->as<kir::TensorIndex>());
+  pushBack(
+      IrBuilder::create<kir::MBarrierWait>(smem_address_ptr, mwait->state()));
+}
+
+void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
+  // If LoadStoreOp has a smem TV in ldstMBarrierTokenMap, then it is a part
+  // of a circular buffer loop. The kir nodes for arrive_expect_tx and
+  // mbarrier_wait are added by the circular buffer pass. Otherwise, those
+  // nodes are added here.
+  bool is_circular_buffered =
+      (GpuLower::current()->ldstMBarrierIndexMap().count(ldst) != 0);
+
+  if (is_circular_buffered) {
+    kir::TensorIndex* mbarrier =
+        GpuLower::current()->ldstMBarrierIndexMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, _] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    // register new LoadStoreOp with mbarrier
+    GpuLower::current()->ldstMBarrierIndexMap()[new_ldst] = mbarrier;
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+  } else {
+    TensorView* mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
+    Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
+
+    // gmem indexing and expect_bytes for mbarrier
+    auto [in, expect_bytes] = Index::getCpAsyncBulkGmemIndex(
+        ldst, mbarrier_index, for_loops_, rotated_loop_);
+
+    // arrive and expect_tx mbarrier
+    Val* state = IrBuilder::create<Val>(DataType::UInt);
+    pushBack(IrBuilder::create<kir::Allocate>(
+        state, MemoryType::Local, ldst->container()->oneVal()));
+    pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
+        state, mbarrier_index, expect_bytes));
+
+    // indexing ldst op
+    Val* out = lowerDstIndex(
+        ldst->out(), /*override_index=*/{}, /*generate_pointer=*/true);
+    Expr* new_ldst =
+        IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
+            ->withPredicate(ldst->predicate());
+    pushBack(new_ldst);
+
+    GpuLower::current()->propagateExprInfo(ldst, back());
+    // wait mbarrier
+    pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
+  }
 }
 
 void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
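For readers unfamiliar with the protocol these kir nodes encode, the non-circular-buffered branch above emits the same init / arrive-expect-tx / bulk-copy / wait sequence that the CUDA programming guide expresses with libcu++. The sketch below is illustrative only and is not nvFuser's generated kernel code: the kernel name, buffer size, and alignment are made up, gmem_src is assumed 16-byte aligned, and the code requires sm_90.

// A minimal sketch of the CUDA-level pattern that kir::MBarrierInit,
// kir::MBarrierArriveExpectTx, and kir::MBarrierWait correspond to, written
// against libcu++'s <cuda/barrier>.
#include <cuda/barrier>
#include <cuda/std/utility>

using barrier = cuda::barrier<cuda::thread_scope_block>;
namespace cde = cuda::device::experimental;

__global__ void tma_load_sketch(const float* gmem_src) {
  __shared__ alignas(128) float smem_buf[1024];
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bar;

  if (threadIdx.x == 0) {
    init(&bar, blockDim.x); // ~ kir::MBarrierInit(smem_addr, thread_count)
    cde::fence_proxy_async_shared_cta(); // make the init visible to TMA
  }
  __syncthreads();

  barrier::arrival_token token;
  if (threadIdx.x == 0) {
    // ~ the lowered cp.async.bulk LoadStoreOp writing into shared memory
    cde::cp_async_bulk_global_to_shared(
        smem_buf, gmem_src, sizeof(smem_buf), bar);
    // ~ kir::MBarrierArriveExpectTx(state, smem_addr, expect_bytes)
    token = cuda::device::barrier_arrive_tx(bar, 1, sizeof(smem_buf));
  } else {
    token = bar.arrive();
  }
  // ~ kir::MBarrierWait(smem_addr, state)
  bar.wait(cuda::std::move(token));
  // smem_buf now holds the loaded tile for every thread in the block.
}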
diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h
index 448933d3835..feae4ccc733 100644
--- a/csrc/device_lower/pass/index.h
+++ b/csrc/device_lower/pass/index.h
@@ -74,6 +74,8 @@ class IndexLowering : private OptOutConstDispatch {
   void handle(const kir::GridSync*) final;
   void handle(const kir::MBarrierInit*) final;
   void handle(const kir::MBarrierInvalidate*) final;
+  void handle(const kir::MBarrierArriveExpectTx*) final;
+  void handle(const kir::MBarrierWait*) final;
   void handle(const kir::AsyncWait*) final;
   void handle(const kir::AsyncCommit*) final;
   void handle(const kir::BlockSerializeWait*) final;
diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index a75f95c4ab0..eabeaa3f097 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -838,6 +838,14 @@ Val* u32IndexScalarSmemTv(TensorView* smem_tv) {
   return u32addr;
 }
 
+Val* u32IndexScalarSmemTv(kir::TensorIndex* index) {
+  auto ptr_address = IrBuilder::addressExpr(index);
+  auto u32addr = IrBuilder::create<Val>(DataType::SMemAddress);
+  IrBuilder::create<UnaryOp>(
+      UnaryOpType::ToUnsignedSmemAddr, u32addr, ptr_address);
+  return u32addr;
+}
+
 Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
   // See the comment above for getGridCommWorkBufferSize.
   NVF_ERROR(
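The new u32IndexScalarSmemTv overload takes the address of a kir::TensorIndex and wraps it in UnaryOpType::ToUnsignedSmemAddr. At the kernel level that operation amounts to converting a generic pointer into the 32-bit shared-memory window address that mbarrier and TMA PTX instructions expect. A hedged sketch using the standard __cvta_generic_to_shared intrinsic; the helper name and usage line are illustrative, not nvFuser's runtime API:

// Illustrative device helper: what ToUnsignedSmemAddr boils down to in the
// generated kernel. __cvta_generic_to_shared is a CUDA intrinsic; shared
// window addresses fit in 32 bits, so the narrowing cast is lossless.
__device__ inline unsigned int toUnsignedSmemAddr(const void* smem_ptr) {
  return static_cast<unsigned int>(__cvta_generic_to_shared(smem_ptr));
}

// Hypothetical usage: the value handed to an mbarrier instruction is the
// u32 form of a per-stage shared-memory element, not a TensorView index.
// unsigned int addr = toUnsignedSmemAddr(&smem_barriers[stage]);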
diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h
index fa192dca180..6661a650b47 100644
--- a/csrc/device_lower/utils.h
+++ b/csrc/device_lower/utils.h
@@ -304,6 +304,10 @@ bool isExtentEqualToMaxParallelTypeExtent(const IterDomain* id);
 //! indexing special items in shared memory, like mbarrier.
 NVF_API Val* u32IndexScalarSmemTv(TensorView* tv);
 
+//! Get the uint32_t index of a TensorIndex. This is usually used for
+//! initializing a pipeline of mbarriers.
+NVF_API Val* u32IndexScalarSmemTv(kir::TensorIndex* index);
+
 //! Get the size of a global sync buffer needed to perform a grid reduction for
 //! each axis in bitmap.
 Val* getGridSyncBufferSize(const ParallelTypeBitmap& bitmap);
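The header comment above mentions initializing a pipeline of mbarriers. Below is a hedged sketch of what such a pipeline looks like at the CUDA level: one barrier per circular-buffer stage, initialized once before the stages start cycling. The kernel name and kNumStages are illustrative; this is not code generated by nvFuser.

#include <cuda/barrier>

using barrier = cuda::barrier<cuda::thread_scope_block>;

__global__ void mbarrier_pipeline_init_sketch() {
  constexpr int kNumStages = 4; // hypothetical circular-buffer depth
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ barrier bars[kNumStages];

  if (threadIdx.x == 0) {
    for (int stage = 0; stage < kNumStages; ++stage) {
      // One barrier per stage; a kir::MBarrierInit indexed by a
      // kir::TensorIndex plays the role of init(&bars[stage], ...) here.
      init(&bars[stage], blockDim.x);
    }
  }
  __syncthreads(); // all threads observe the initialized barriers
}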