diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index a0d38674095..c69f2755078 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -23,11 +23,13 @@ namespace { // buffer tensor is on smem, in which case it would otherwise require // an additional predicate to guard buffer overruns. When it's on // gmem, that isn't the case, so it does not need to create an -// epilogue loop. +// epilogue loop. For TMA cpAsyncBulk, there is always an epilogue loop. bool requireEpilogue(const std::vector& exprs) { return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) { - return expr->input(0)->as()->getMemoryType() == - MemoryType::Shared; + return (expr->input(0)->as()->getMemoryType() == + MemoryType::Shared) || + (expr->as()->opType() == + LoadStoreOpType::CpAsyncBulkTensorTile); }); } @@ -118,8 +120,8 @@ class CircularBufferLoopCloner : public kir::IrVisitor { start, stop, /*step=*/GpuLower::current()->kernel()->oneVal(), - /*step=*/false, - /*vectorize=*/nullptr, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, circular_buffer_loop_->isUnrollRequired(), loop_type_); @@ -131,16 +133,24 @@ class CircularBufferLoopCloner : public kir::IrVisitor { ? cloned_top_level_loop_ : IrBuilder::create(fl); - cloned_scopes_.push_back(&cloned_loop->body()); + // Add to stack + for_loop_stack_.push_back(cloned_loop); + // Process for-loop kir::IrVisitor::handle(fl); - cloned_scopes_.pop_back(); + // Pop from stack + for_loop_stack_.pop_back(); + // Specific handling of for-loop + processForLoop(cloned_loop); + } + + virtual void processForLoop(ForLoop* cloned_loop) { // Add the cloned loop into the parent loop body only when the // cloned loop contains expressions. - if (!cloned_loop->body().empty() && !cloned_scopes_.empty()) { - cloned_scopes_.back()->push_back(cloned_loop); + if (!cloned_loop->body().empty() && !for_loop_stack_.empty()) { + for_loop_stack_.back()->body().push_back(cloned_loop); } } @@ -149,7 +159,7 @@ class CircularBufferLoopCloner : public kir::IrVisitor { } void dispatch(Expr* expr) override { - // skip expression if it is in exclude set + // Skip expression if it is in exclude set if (exclude_.count(expr) > 0) { return; } @@ -160,8 +170,13 @@ class CircularBufferLoopCloner : public kir::IrVisitor { return; } - NVF_ERROR(!cloned_scopes_.empty()); + NVF_ERROR(!for_loop_stack_.empty()); + + // Specific expression handling + processExpr(expr); + } + virtual void processExpr(Expr* expr) { switch (loop_type_) { case CircularBufferLoopStage::Prolog: { // In Prologue, only copy the load expressions. @@ -169,19 +184,19 @@ class CircularBufferLoopCloner : public kir::IrVisitor { // circular buffered TVs (e.g., buffer initialization). TensorView* out_tv = ir_utils::getTvOutput(expr); if (circular_buffer_load_tvs_.count(out_tv) > 0) { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); } break; } case CircularBufferLoopStage::Main: { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); break; } case CircularBufferLoopStage::Epilog: { // In Epilogue, copy everything except circular buffer load expressions. TensorView* out_tv = ir_utils::getTvOutput(expr); if (circular_buffer_load_tvs_.count(out_tv) == 0) { - cloned_scopes_.back()->push_back(expr); + for_loop_stack_.back()->body().push_back(expr); } break; } @@ -191,17 +206,705 @@ class CircularBufferLoopCloner : public kir::IrVisitor { } } - private: + protected: ForLoop* circular_buffer_loop_ = nullptr; const std::vector& circular_buffer_load_exprs_; const CircularBufferLoopStage loop_type_; std::unordered_set circular_buffer_load_tvs_; ForLoop* cloned_top_level_loop_ = nullptr; - std::deque cloned_scopes_; + std::vector for_loop_stack_; const std::unordered_set& exclude_; }; +// TODO Replace with elect_sync ptx +// TMA operation only a single thread is necessary to launch TMA operations. +// This function creates kir::IfThenElse with the following predicate: +// threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 +kir::IfThenElse* createThreadPredicatedIfThenElse() { + Val* zero_val = IrBuilder::create(0L, PrimDataType::UInt); + Val* if_predicate_expr = IrBuilder::logicalAndExpr( + IrBuilder::logicalAndExpr( + IrBuilder::eqExpr( + NamedScalar::getParallelIndex(ParallelType::TIDx), zero_val), + IrBuilder::eqExpr( + NamedScalar::getParallelIndex(ParallelType::TIDy), zero_val)), + IrBuilder::eqExpr( + NamedScalar::getParallelIndex(ParallelType::TIDz), zero_val)); + + kir::IfThenElse* if_expr = IrBuilder::create( + IrBuilder::create(if_predicate_expr)); + + return if_expr; +} + +// Description: +// Replicates circular buffer loops for Prologue, Main, and +// Epilogue. Prologue only copies the load expressions of circular +// buffered tensors, whereas Epilogue does any expression other than +// the loads. Main copies everything. +// +// Loop Structure Overview: +// Pre-prologue loop: +// - Allocate shared memory for mbarriers and mbarrier tokens +// - Initialize mbarrier for all stages +// +// Prologue loop: +// - if selected_thread: +// - Issue cp async bulks for all but last stage +// +// Main loop: +// - if selected_thread: +// - Issue next cp async bulk for available stage +// - All threads wait until tma operation arrives +// - Copy body without +// - shared memory allocations +// - mbarrier_init exprs +// - mbarrier_inval exprs +// +// Epilogue loop: +// - All threads wait until tma operation arrives +// - Copy body without +// - shared memory allocations +// - issuing cp async bulk operations +// - mbarrier_init exprs +// - mbarrier_inval exprs +// +// Post-epilogue loop: +// - if selected_thread: +// - Invalidated mbarrier for all stages +// +// Detailed Pseudo-Code: +// Pre-Prologue loop: +// __shared__ __mbarrier_t barriers[num_stages]; +// __shared__ __mbarrier_token_t tokens[num_stages]; +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// for (int64_t loop_index : irange(stages)) { +// mbarrier_init(mbarrier[loop_index], number_of_arrival_threads); +// } +// } +// +// Prologue loop: +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// for (int64_t loop_index : irange(stages-1)) { +// tokens[loop_index] = mbarrier::arriveExpectTx(mbarrier[loop_index]); +// cpAsyncBulk(mbarriers[loop_index], ...); +// } +// } +// +// Main loop: +// for (int64_t loop_index : irange(N-(stages-1))) { +// current_stage = loop_index % stage_depth +// load_stage = (loop_index + (stage_depth - 1)) % stage_depth) +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// token[load_stage] = +// mbarrier::arriveExpectTx(mbarrier[load_stage]); +// cpAsyncBulk(mbarrier[load_stage], ...); +// } +// mbarrier::wait(token[current_stage]); +// +// Clone remaining operations +// } +// +// Epilogue loop: +// for (int64_t loop_index : irange(N-(stages-1), N)) { +// current_stage = loop_index % stage_depth +// mbarrier::wait(token[current_stage]); +// +// Clone remaining operations +// } +// +// Post-Epilogue loop: +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// for (int64_t loop_index : irange(stages)) { +// mbarrier_inval(mbarrier[loop_index]); +// } +// } +// +class TmaCircularBufferLoopCloner : public CircularBufferLoopCloner { + public: + static ForLoop* clone( + ForLoop* circular_buffer_loop, + const std::vector& circular_buffer_load_exprs, + CircularBufferLoopStage loop_type, + const std::unordered_set& exclude = {}) { + TmaCircularBufferLoopCloner cloner( + circular_buffer_loop, circular_buffer_load_exprs, loop_type, exclude); + cloner.duplicate(); + return cloner.cloned_top_level_loop_; + } + + private: + TmaCircularBufferLoopCloner( + ForLoop* circular_buffer_loop, + const std::vector& circular_buffer_load_exprs, + CircularBufferLoopStage loop_type, + const std::unordered_set& exclude) + : CircularBufferLoopCloner( + circular_buffer_loop, + circular_buffer_load_exprs, + loop_type, + exclude) {} + + void processForLoop(ForLoop* cloned_loop) final { + // Skip if there is not an active for-loop structure + if (for_loop_stack_.empty()) { + return; + } + + if (!cloned_loop->body().empty()) { + if (mbarrier_arrive_tx_ == nullptr || for_loop_stack_.size() > 1) { + // Add cloned for_loop when mbarrier_arrive_tx_ is not active or + // we are within a nested for-loop structure + for_loop_stack_.back()->body().push_back(cloned_loop); + } else { + // mbarrier::arriveExpectTx and TMA load operations occur in prologue + // and main loops. + NVF_ERROR(for_loop_stack_.front() == cloned_top_level_loop_); + addTmaLoadBlock(cloned_loop); + } + } + + // mbarrier::wait occurs in Main and Epilogue loops. + if (mbarrier_wait_ != nullptr && for_loop_stack_.size() == 1) { + NVF_ERROR(for_loop_stack_.back() == cloned_top_level_loop_); + addSynchronousMbarrierWait(); + } + } + + void processExpr(Expr* expr) final { + bool mbarrier_token_exists = + GpuLower::current()->ldstMBarrierTokenMap().count(expr) != 0; + + bool is_ignorable_tma_smem_alloc = + (GpuLower::current()->mBarrierTokenSmemAllocSet().count(expr) != 0); + + bool is_ignorable_mbarrier_init = + (expr->isA() && mbarrier_token_exists); + + bool is_ignorable_mbarrier_inval = + (expr->isA() && mbarrier_token_exists); + + // Short-Circuit + switch (loop_type_) { + case CircularBufferLoopStage::Prolog: { + // Skip expression if it is not circular buffer expression + TensorView* out_tv = ir_utils::getTvOutput(expr); + bool is_circular_buffer_load_expr = std::any_of( + circular_buffer_load_exprs_.begin(), + circular_buffer_load_exprs_.end(), + [out_tv](Expr* load_expr) { + TensorView* circular_buffer_tv = ir_utils::getTvOutput(load_expr); + NVF_ERROR(circular_buffer_tv != nullptr); + return out_tv == circular_buffer_tv; + }); + if (!is_circular_buffer_load_expr) { + return; + } + // NOTE: There can be circular buffered TVs without TMA load exprs. + if (!mbarrier_token_exists) { + for_loop_stack_.back()->body().push_back(expr); + return; + } + break; + } + case CircularBufferLoopStage::Main: + case CircularBufferLoopStage::Epilog: { + // Skip shared memory allocation, mbarrier initialize and mbarrier + // invalidate for main and epilog loops + if (is_ignorable_tma_smem_alloc || is_ignorable_mbarrier_init || + is_ignorable_mbarrier_inval) { + return; + } + + // Add expression if not circular-buffered load store operation + if (!expr->isA() || !mbarrier_token_exists) { + for_loop_stack_.back()->body().push_back(expr); + return; + } + break; + } + case CircularBufferLoopStage::NotApplicable: { + NVF_ERROR(false, "Unsupported loop mode, got: ", loop_type_); + } + } + + switch (loop_type_) { + case CircularBufferLoopStage::Prolog: { + return handlePrologueLoop(expr); + } + case CircularBufferLoopStage::Main: { + return handleMainLoop(expr); + } + case CircularBufferLoopStage::Epilog: { + return handleEpilogLoop(expr); + } + case CircularBufferLoopStage::NotApplicable: { + NVF_ERROR(false, "Unsupported loop mode, got: ", loop_type_); + } + } + } + + // Replace cpAsyncBulk type LoadStoreOp with: + // if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + // for (int64_t loop_idx : irange(stages-1)) { + // tokens[loop_idx] = + // mbarrier::arriveExpectTx(mbarrier[loop_idx]) + // cpAsyncBulk(mbarrier[loop_idx], ...); + // } + // } + void handlePrologueLoop(Expr* expr) { + NVF_ERROR(expr != nullptr); + + // Skip if not LoadStoreOp expression + if (!expr->isA()) { + return; + } + + LoadStoreOp* ldst = expr->as(); + + // There should be a single mbarrier_arrive_tx_ for all ldst in current + // stage. + NVF_ERROR(mbarrier_arrive_tx_ == nullptr); + mbarrier_arrive_tx_ = createMbarrierArriveExpectTx( + ldst, cloned_top_level_loop_->indexOrStartIfTrivial()); + + // Clone LoadStoreOp and map it to mbarrier alloc + Expr* new_ldst = + IrBuilder::create( + ldst->opType(), ldst->out(), ldst->in(), ldst->cacheOp()) + ->withPredicate(ldst->predicate()); + + // Register mbarrier object to be used with new LoadStoreOp + // from prolog loop + NVF_ERROR(mbarrier_arrive_tx_->mbarrier()->isA()); + GpuLower::current()->ldstMBarrierIndexMap().emplace( + new_ldst, mbarrier_arrive_tx_->mbarrier()->as()); + + // If last cloned scope is the cloned_top_level_loop body, then add + // mbarrier::arriveExpectTx and new loadStoreOp. + int64_t active_for_loops = std::count_if( + for_loop_stack_.begin(), for_loop_stack_.end(), [](ForLoop* fl) { + return fl->iter_domain()->getParallelType() == ParallelType::Serial; + }); + if (active_for_loops == 1) { + return addTmaLoadBlock(new_ldst); + } + + // Otherwise, we are in a nested for-loop and should wait until we + // return to top-level for loop. + for_loop_stack_.back()->body().push_back(new_ldst); + } + + // Handle cpAsyncBulk type LoadStoreOp that is registered with token + // + // current_compute_stage = loop_index % stage_depth + // current_load_stage = (loop_index + (stage_depth - 1)) % stage_depth) + // + // Replace LoadStoreOp with: + // if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + // tokens[current_load_stage] = + // mbarrier::arriveExpectTx(mbarrier[current_load_stage]); + // cpAsyncBulk(mbarrier[current_load_stage], ...); + // } + // mbarrier::wait(token[current_stage]); + // + // Where mbarrier and token are shared memory arrays bound to the + // LoadStoreOp + void handleMainLoop(Expr* expr) { + NVF_ERROR(expr != nullptr && expr->isA()); + + int64_t stage_depth = + GpuLower::current()->circularBufferInfo().getStageDepthFor( + circular_buffer_loop_->iter_domain()); + + if (current_compute_stage_ == nullptr) { + current_compute_stage_ = IrBuilder::modExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::create(stage_depth, PrimDataType::Index)); + kir::Allocate* current_compute_stage_alloc = + IrBuilder::create( + current_compute_stage_, + MemoryType::Local, + IrBuilder::create(1L, PrimDataType::Index), + /*zero_init=*/false); + cloned_top_level_loop_->body().push_back(current_compute_stage_alloc); + cloned_top_level_loop_->body().push_back( + current_compute_stage_->definition()); + } + + if (current_load_stage_ == nullptr) { + current_load_stage_ = IrBuilder::modExpr( + IrBuilder::addExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::subExpr( + IrBuilder::create(stage_depth, PrimDataType::Index), + IrBuilder::create(1L, PrimDataType::Index))), + IrBuilder::create(stage_depth, PrimDataType::Index)); + kir::Allocate* current_load_stage_alloc = + IrBuilder::create( + current_load_stage_, + MemoryType::Local, + IrBuilder::create(1L, PrimDataType::Index), + /*zero_init=*/false); + cloned_top_level_loop_->body().push_back(current_load_stage_alloc); + cloned_top_level_loop_->body().push_back( + current_load_stage_->definition()); + } + + LoadStoreOp* ldst = expr->as(); + + // There should be a single mbarrier_arrive_tx_ for all ldst in current + // stage. + NVF_ERROR(mbarrier_arrive_tx_ == nullptr); + mbarrier_arrive_tx_ = + createMbarrierArriveExpectTx(ldst, current_load_stage_); + + // Register mbarrier object to be used with LoadStoreOp + // from main loop + NVF_ERROR(mbarrier_arrive_tx_->mbarrier()->isA()); + GpuLower::current()->ldstMBarrierIndexMap().emplace( + ldst, mbarrier_arrive_tx_->mbarrier()->as()); + + // Construct mBarrier::wait for current stage + NVF_ERROR( + mbarrier_wait_ == nullptr, + "Expected mbarrier_wait to inactive for current TMA operation"); + mbarrier_wait_ = createMbarrierWait(ldst, current_compute_stage_); + + // If last cloned scope is the cloned_top_level_loop body, then add + // mbarrier::arriveExpectTx and new loadStoreOp. + int64_t active_for_loops = std::count_if( + for_loop_stack_.begin(), for_loop_stack_.end(), [](ForLoop* fl) { + return fl->iter_domain()->getParallelType() == ParallelType::Serial; + }); + if (active_for_loops == 1) { + addTmaLoadBlock(ldst); + return; + } + + // Otherwise, we are in a nested for-loop and should wait until we + // return to top-level for loop. + for_loop_stack_.back()->body().push_back(ldst); + } + + void handleEpilogLoop(Expr* expr) { + NVF_ERROR(expr != nullptr && expr->isA()); + + // Construct mBarrier::wait for epilogue + LoadStoreOp* ldst = expr->as(); + int64_t stage_depth = + GpuLower::current()->circularBufferInfo().getStageDepthFor( + circular_buffer_loop_->iter_domain()); + Val* epilogue_compute_stage = IrBuilder::modExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::create(stage_depth, PrimDataType::Index)); + + NVF_ERROR( + mbarrier_wait_ == nullptr, + "Expected mbarrier_wait to inactive for current TMA operation"); + mbarrier_wait_ = createMbarrierWait(ldst, epilogue_compute_stage); + } + + // This function selects a single thread to launch tma load and mbarrier + // arrive_expected_tx operations. + // + // Pseudo-code example: + // if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + // tokens[next_stage] = + // mbarrier::arriveExpectTx(mbarrier[next_stage]); + // cpAsyncBulk(mbarrier[next_stage], ...); + // } + void addTmaLoadBlock(Expr* expr) { + NVF_ERROR(expr != nullptr); + kir::IfThenElse* if_expr = createThreadPredicatedIfThenElse(); + Scope& body = if_expr->thenBody(); + body.push_back(mbarrier_arrive_tx_); + body.push_back(expr); + for_loop_stack_.back()->body().push_back(if_expr); + mbarrier_arrive_tx_ = nullptr; + } + + // This function adds mbarrier::wait to top level cloned loop. + // + // Pseudo-code example: + // __syncthreads(); + // mbarrier::wait(mbarriers[stage], mbarrier_tokens[stage]); + void addSynchronousMbarrierWait() { + NVF_ERROR(mbarrier_wait_ != nullptr); + + // The Mbarrier Wait condition is a single thread and the expected bytes + // for TMA operation. Since the total number of threads is unknown, we + // use a block sync to prevent race conditions. + kir::BlockSync* sync_expr = + IrBuilder::create(/*war_sync=*/true); + cloned_top_level_loop_->body().push_back(sync_expr); + + // TODO Use total number of threads of CTA with mbarrier_wait + // TODO Create analysis to determine when block sync is required + cloned_top_level_loop_->body().push_back(mbarrier_wait_); + + mbarrier_wait_ = nullptr; + } + + // This function creates kir::MBarrierArriveExpectTx for given LoadStoreOp and + // circular buffer stage. + // + // Example: + // __shared__ __mbarrier_t barriers[num_stages]; + // __shared__ __mbarrier_token_t tokens[num_stages]; + // for(nvfuser_index_t stage = 0; stage < num_stages; ++stage) { + // if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + // tokens[stage] = + // mbarrier::arriveExpectTX(toSmem((&barriers[stage])), + // expected_bytes); + // } + // } + kir::MBarrierArriveExpectTx* createMbarrierArriveExpectTx( + LoadStoreOp* ldst, + Val* loop_index) { + NVF_ERROR(ldst != nullptr); + NVF_ERROR(loop_index != nullptr); + + TensorView* consumer_tv = ldst->out()->as(); + NVF_ERROR( + GpuLower::current()->consumerToTMAInfo().count(consumer_tv), + "Unable to find TMA info for consumer_tv: ", + consumer_tv->toString()); + + // Get expected bytes for given TMA load operation. + const TMAInfo& tma_info = + GpuLower::current()->consumerToTMAInfo().at(consumer_tv); + Val* expected_bytes = tma_info.tileSizeBytes(); + + // The expected_bytes for mbarrier::arriveExpectTX must account for all TMA + // load operations launched for each circular buffer stage. We take the + // product of all coordinate TMA iterDomains to the right of the circular + // buffer axis. + const std::vector& leaf_domain = consumer_tv->getLoopDomain(); + for (size_t idx = consumer_tv->getComputeAtPosition(); + idx < leaf_domain.size(); + ++idx) { + IterDomain* id = leaf_domain.at(idx); + if (!isParallelTypeThread(id->getParallelType()) && + id->getParallelType() != ParallelType::Bulk) { + expected_bytes = + SimplifyingIrBuilder::mulExpr(expected_bytes, id->extent()); + } + } + expected_bytes = + SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, expected_bytes); + + auto is_multiple_of_16B = SimplifyingIrBuilder::eqExpr( + SimplifyingIrBuilder::modExpr( + expected_bytes, IrBuilder::create(16, DataType::Index)), + expected_bytes->fusion()->zeroVal()); + GpuLower::current()->validate( + is_multiple_of_16B, + "The expected bytes must be a multiple of 16 bytes, but ", + expected_bytes, + " is not."); + + // Get mbarrier for this circular buffer stage. + TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + kir::TensorIndex* stage_mbarrier = + IrBuilder::create(all_mbarriers, loop_index); + + // Get mbarrier_token for this circular buffer stage. + TensorView* all_mbarrier_tokens = + GpuLower::current()->ldstMBarrierTokenMap().at(ldst); + kir::TensorIndex* stage_token = + IrBuilder::create(all_mbarrier_tokens, loop_index); + + kir::MBarrierArriveExpectTx* mbarrier_arrive_tx = + IrBuilder::create( + stage_token, stage_mbarrier, expected_bytes); + + return mbarrier_arrive_tx; + } + + // This function creates kir::MBarrierWait for given LoadStoreOp and circular + // buffer stage. + kir::MBarrierWait* createMbarrierWait(LoadStoreOp* ldst, Val* loop_index) { + NVF_ERROR(ldst != nullptr); + NVF_ERROR(loop_index != nullptr); + + // Get mbarrier_token for this circular buffer stage. + TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + kir::TensorIndex* stage_mbarrier = + IrBuilder::create(all_mbarriers, loop_index); + + // Get mbarrier_token for this circular buffer stage. + TensorView* all_mbarrier_tokens = + GpuLower::current()->ldstMBarrierTokenMap().at(ldst); + kir::TensorIndex* stage_token = + IrBuilder::create(all_mbarrier_tokens, loop_index); + + kir::MBarrierWait* mbarrier_wait = + IrBuilder::create(stage_mbarrier, stage_token); + return mbarrier_wait; + } + + private: + // Mbarrier_Wait to add to cloned_top_level_loop + kir::MBarrierWait* mbarrier_wait_ = nullptr; + + // Mbarrier_ArriveExpectTx to add to cloned_top_level_loop + kir::MBarrierArriveExpectTx* mbarrier_arrive_tx_ = nullptr; + + // current_stage_index = (loop_index % stages) + Val* current_compute_stage_ = nullptr; + + // next_stage_index = (loop_index + (stages-1)) % stages + Val* current_load_stage_ = nullptr; +}; + +// This visitor class gathers the shared memory allocations for tokens and +// mbarrier objects. +class GatherMBarrierAllocations : public kir::IrVisitor { + public: + static std::vector create(ForLoop* circular_buffer_loop) { + return GatherMBarrierAllocations().run(circular_buffer_loop); + } + + private: + GatherMBarrierAllocations() = default; + + using kir::IrVisitor::handle; + + std::vector run(ForLoop* circular_buffer_loop) { + handle(circular_buffer_loop); + return new_exprs_; + } + + void handle(ForLoop* fl) final { + kir::IrVisitor::handle(fl); + } + + void handle(kir::IfThenElse* ite) final { + NVF_ERROR(false, "No IfThenElse should exist yet"); + } + + void dispatch(Expr* expr) final { + if (expr->isA() || expr->isA()) { + kir::IrVisitor::dispatch(expr); + return; + } + + // Short-Circuit: Handle only allocate nodes + if (!expr->isA()) { + return; + } + + // Short-Circuit: Handle shared memory allocations + kir::Allocate* alloc = expr->as(); + if (alloc->memoryType() != MemoryType::Shared) { + return; + } + + // Short-Circuit: Handle shared memory allocations for mbarrier + if (GpuLower::current()->mBarrierTokenSmemAllocSet().count(alloc) == 0) { + return; + } + + // Add shared memory allocations for mbarrier and mbarrier tokens + new_exprs_.push_back(expr); + } + + private: + std::vector new_exprs_; +}; + +// This function creates kir::Loop with range based on stage depth. It is +// used for mbarrier initialization and invalidation. +ForLoop* createStageDepthForLoop(ForLoop* circular_buffer_loop) { + int64_t stage_depth = + GpuLower::current()->circularBufferInfo().getStageDepthFor( + circular_buffer_loop->iter_domain()); + + Val* loop_start = IrBuilder::create(0L, PrimDataType::Index); + Val* loop_index = IrBuilder::create(PrimDataType::Index); + Val* loop_stop = IrBuilder::create(stage_depth, DataType::Index); + IterDomainBuilder loop_domain_builder(loop_start, loop_stop); + + ForLoop* loop = IrBuilder::create( + loop_domain_builder.build(), + loop_index, + loop_start, + loop_stop, + /*step=*/GpuLower::current()->kernel()->oneVal(), + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable); + + return loop; +} + +// This helper function creates the pre-prologue and post-epilogue for loops. +// +// The pre-prologue for loop moves the allocation of mbarriers and its tokens +// outside of the main loop. +// +// Expected result: +// Allocate mbarriers and tokens in shared memory +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// for (unsigned i = 0; i < stages; ++i) { +// mbarrier::init(...); +// } +// } +// +// The post-epilogue for loop releases mbarriers after TMA memory +// operations. +// +// Expected result: +// if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { +// for (unsigned i = 0; i < stages; ++i) { +// mbarrier::inval(...); +// } +// } +// +kir::IfThenElse* createCpAsyncBulkFixtures( + ForLoop* circular_buffer_loop, + const std::vector& circular_buffer_load_exprs, + bool is_pre_prologue_stage) { + // Construct predicate to select a single thread. + kir::IfThenElse* if_expr = createThreadPredicatedIfThenElse(); + + // Construct ForLoop + ForLoop* loop = createStageDepthForLoop(circular_buffer_loop); + + for (const Expr* ldst : circular_buffer_load_exprs) { + // Short-Circuit: Handle ldst operations associated with mbarrier + if (GpuLower::current()->ldstMBarrierMap().count(ldst) == 0) { + continue; + } + + // Get mbarrier for this circular buffer stage. + TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + kir::TensorIndex* stage_mbarrier = + IrBuilder::create(all_mbarriers, loop->index()); + + if (is_pre_prologue_stage) { + // We expect a single thread to launch transactions and arrive at + // mbarrier_wait. We will use a block sync to handle the remaining + // threads. + kir::MBarrierInit* mbarrier_init = IrBuilder::create( + stage_mbarrier, + /*thread_count=*/IrBuilder::create(1L, PrimDataType::UInt32)); + loop->body().push_back(mbarrier_init); + } else { + // Invalidate the mbarrier for each circular buffer stage. + kir::MBarrierInvalidate* mbarrier_inval = + IrBuilder::create(stage_mbarrier); + loop->body().push_back(mbarrier_inval); + } + } + + if_expr->thenBody().push_back(loop); + return if_expr; +} + using InsertionInfo = std::unordered_map>; class IsCircularBufferLoadLoop : public kir::IrVisitor { @@ -374,11 +1077,74 @@ class CircularBufferInserter : private kir::ExprMutator { return; } - insert(loop, it->second); + auto hasCpAsyncBulk = std::any_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); + + if (hasCpAsyncBulk) { + insertTma(loop, it->second); + } else { + insert(loop, it->second); + } processed_loop_ = loop; insertion_info_.erase(loop); } + void insertTma( + ForLoop* circular_buffer_loop, + const std::vector& loads) { + // Pre-prologue loop: + // - Allocate shared memory for mbarriers and mbarrier tokens + // - Initialize mbarrier for all stages + std::vector smem_allocations = + GatherMBarrierAllocations::create(circular_buffer_loop); + for (Expr* expr : smem_allocations) { + registerInsertBefore(circular_buffer_loop, expr); + } + + kir::IfThenElse* pre_prologue_init = createCpAsyncBulkFixtures( + circular_buffer_loop, loads, /*is_pre_prologue_stage=*/true); + NVF_ERROR(pre_prologue_init != nullptr); + registerInsertBefore(circular_buffer_loop, pre_prologue_init); + + // Block sync is necessary to finish mbarrier initialization. + kir::BlockSync* sync = IrBuilder::create(false); + registerInsertBefore(circular_buffer_loop, sync); + + // Prologue loop: + // - launch only + // - arrive_expect_tx and tma load operations + ForLoop* prologue_loop = TmaCircularBufferLoopCloner::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::Prolog); + registerInsertBefore(circular_buffer_loop, prologue_loop); + + // Main loop: + // - Launch and wait + // - arrive_expect_tx, tma load operations, and mbarrier_wait) + ForLoop* main_loop = TmaCircularBufferLoopCloner::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::Main); + registerReplace(circular_buffer_loop, main_loop); + + // We can use exclude argument in TmaCircularBufferLoopCloner clone to + // avoid duplicating allocations if main loop is trivial. However, this + // causes the warp_reduce pass to fail with persistent kernels because it + // cannot find the allocation for reduction operation. + + // Epilogue loop: + // - wait only + // - mbarrier_wait + ForLoop* epilogue_loop = TmaCircularBufferLoopCloner::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::Epilog); + registerInsertAfter(circular_buffer_loop, epilogue_loop); + + // Post-epilogue loop: + // - if selected_thread: + // - Invalidated mbarrier for all stages + kir::IfThenElse* post_epilogue_inval = createCpAsyncBulkFixtures( + circular_buffer_loop, loads, /*is_pre_prologue_stage=*/false); + NVF_ERROR(post_epilogue_inval != nullptr); + registerInsertAfter(epilogue_loop, post_epilogue_inval); + } + void insert(ForLoop* circular_buffer_loop, const std::vector& loads) { ForLoop* prologue_loop = CircularBufferLoopCloner::clone( circular_buffer_loop, loads, CircularBufferLoopStage::Prolog); diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 0861b227cac..bb2f4c6c8a2 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -6,6 +6,7 @@ */ // clang-format on +#include #include #include #include @@ -802,4 +803,678 @@ INSTANTIATE_TEST_SUITE_P( CircularBufferingTest, ::testing::Range(2, 10)); +using TmaCircularBufferingParams = std::tuple; + +class TmaCircularBufferingTest + : public NVFuserFixtureParamTest { + protected: + int64_t number_of_stages = 1; + int64_t tensor_outer_dim = 1; + int64_t tensor_inner_dim = 1; + + void SetUp() override { + number_of_stages = std::get<0>(GetParam()); + tensor_outer_dim = std::get<1>(GetParam()); + tensor_inner_dim = std::get<2>(GetParam()); + + // NOTE: Multiple of 16 required for inner dimension + NVF_ERROR(tensor_inner_dim % 16 == 0); + NVFuserTest::SetUp(); + } + + template + void compare(int64_t tensor_dim, at::Tensor result, at::Tensor reference) { + at::Tensor reference_cpu_data = reference.cpu(); + at::Tensor result_cpu_data = result.cpu(); + + auto reference_cpu = reference_cpu_data.accessor(); + auto result_cpu = result_cpu_data.accessor(); + + constexpr double tolerance = 1e-3; + for (int64_t pos = 0; pos < tensor_dim; ++pos) { + if (fabs((double)result_cpu[pos] - (double)reference_cpu[pos]) > + tolerance) { + std::cout << "[" << pos << "] - result: " << result_cpu[pos] + << " | reference: " << reference_cpu[pos] << std::endl; + } + } + } + + template + void compare( + int64_t tensor_outer_dim, + int64_t tensor_inner_dim, + at::Tensor result, + at::Tensor reference) { + at::Tensor reference_cpu_data = reference.cpu(); + at::Tensor result_cpu_data = result.cpu(); + + auto reference_cpu = reference_cpu_data.accessor(); + auto result_cpu = result_cpu_data.accessor(); + + constexpr double tolerance = 1e-3; + for (int64_t out_pos = 0; out_pos < tensor_outer_dim; ++out_pos) { + for (int64_t in_pos = 0; in_pos < tensor_inner_dim; ++in_pos) { + if (fabs( + (double)reference_cpu[out_pos][in_pos] - + (double)result_cpu[out_pos][in_pos]) > tolerance) { + std::cout << "[" << out_pos << ", " << in_pos + << "] - result: " << result_cpu[out_pos][in_pos] + << " | ref: " << reference_cpu[out_pos][in_pos] + << std::endl; + } + } + } + } +}; + +TEST_P(TmaCircularBufferingTest, SingleDim) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set computeAt before applying circular buffer + tv0->computeAt(tv1, 1); + + // Circular Buffer with TMA loads + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, SingleDimUnroll) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t unroll_dim = 4; + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + // [M/bid, bid] -> [M/bid/unroll, unroll, bid] + reference->split(0, unroll_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // ComputeAt + tv0->computeAt(tv1, 1); + + // Apply Unroll + tv1->axis(1)->parallelize(ParallelType::Unroll); + + // Circular Buffer with TMA loads + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + int64_t axis_extent = + ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim); + if (axis_extent < number_of_stages) { + ASSERT_ANY_THROW(fe.runFusion({t0})); + return; + } + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, SingleDimUnswitch) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t unroll_dim = 4; + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + // [M/bid, bid] -> [M/bid/unroll, unroll, bid] + reference->split(0, unroll_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // ComputeAt + tv0->computeAt(tv1, 1); + + // Apply Unswitch + tv1->axis(1)->parallelize(ParallelType::Unswitch); + + // Circular Buffer with TMA loads + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + int64_t axis_extent = + ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim); + if (axis_extent < number_of_stages) { + ASSERT_ANY_THROW(fe.runFusion({t0})); + return; + } + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, MultiDim) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr int64_t tma_outer_dim = 4; + constexpr int64_t tma_inner_dim = 32; + + // [M, N] -> [M, N/bid, bid] + reference->split(-1, tma_inner_dim); + // [M, N/bid, bid] -> [M/bod, bod, N/bid, bid] + reference->split(0, tma_outer_dim); + // [M/bod, bod, N/bid, bid] -> [M/bod, N/bid, bod, bid] + reference->reorder({{-2, -3}}); + + // Propagate TMA transform + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Apply computeAt for TMA cache + tv0->computeAt(tv1, 2); + + // Merge TMA tile and Parallelize + // [M/bod, N/bid, bod, bid] -> [M/bod, N/bid, bod * bid] + reference->merge(-2, -1); + // [M/bod, N/bid, bod * bid] -> [M/bod, N/bid, (bod * bid) / 128, 128] + reference->split(-1, 128); + + // Parallelize + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + // Circular Buffer with TMA loads + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->axis(-2)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::ones({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Pointwise) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + // Use TMA to load TV0 into shared memory + TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv3->setMemoryType(MemoryType::Shared); + + // Load TV0 into shared memory + TensorView* tv4 = tv1->cacheAfter(); + tv4->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv2; + + // Constants + constexpr int64_t bulk_inner_dim = 32; + + // [M, N] -> [M, N/bid, bid] + reference->split(-1, bulk_inner_dim); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + tv0->computeAt(tv2, 2); + tv1->computeAt(tv2, 2); + + // Ciruclar Buffer with TMA loads + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(2)->parallelize(ParallelType::Bulk); + tv3->circularBuffer(number_of_stages); + + // Ciruclar Buffer with set operation + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->circularBuffer(number_of_stages); + + // Split reference to parallelize TMA tile + reference->split(-1, 32); + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t2 = t0 + t1; + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0, t1}); + + std::vector cg_outputs = fe.runFusion({t0, t1}); + compare(tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), t2); + testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Reduction) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + + TensorView* tv1 = sum(tv0, {-1}); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + constexpr int64_t examples_per_cta = 4; + constexpr int64_t bulk_inner_dim = 256; + + // [M, N] -> [M/epc, epc, N] + reference->split(0, examples_per_cta); + // [M/epc, epc, N] -> [M/epc, epc, N/bid, bid] + reference->split(-1, bulk_inner_dim); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // [M/epc, epc, N/bid, bid] -> [M/epc, epc, N] + reference->merge(-2, -1); + // [M/epc, epc, N] -> [M/epc, epc, N/tdx, tdx] + constexpr int64_t tdx = 128; + reference->split(-1, tdx); + + // Parallelize + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + // InlineMost automatically handles vectorize and tma dimensions + inlineMost(); + + // Circular Buffer with TMA loads + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = sum(t0, {-1}); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_outer_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Persistent) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + constexpr at::ScalarType dtype = at::ScalarType::Float; + constexpr int64_t correction = 0; + constexpr int64_t reduction_axis = 1; + constexpr bool keepdim = true; + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* x = makeContigTensor(2, aten_to_data_type(dtype)); + fusion->addInput(x); + + // Algorithm: + // x_norm = (x - x_mean) / sqrt(x_var) + Val* num_elem = x->getLoopDomain().at(reduction_axis)->extent(); + + TensorView* sum_x = sum(x, {reduction_axis}, /*keepdim=*/false); + TensorView* mean_x = div(sum_x, num_elem); + TensorView* bcast_mean = broadcast(mean_x, {false, true}); + + TensorView* x_mean_sub = sub(x, bcast_mean); + TensorView* x_mean_sub_sq = mul(x_mean_sub, x_mean_sub); + TensorView* sum_x_mean_sub_sq = + sum(x_mean_sub_sq, {reduction_axis}, /*keepdim=*/false); + TensorView* var_x = div(sum_x_mean_sub_sq, num_elem); + TensorView* bcast_var = broadcast(var_x, {false, true}); + + TensorView* x_norm = div(sub(x, bcast_mean), sqrt(bcast_var)); + fusion->addOutput(x_norm); + + // Load input from global to shared memory + TensorView* x_cache_smem = + x->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + x_cache_smem->setMemoryType(MemoryType::Shared); + + // Load input from shared memory to registers + x_cache_smem->cacheAfter(); + + // Store results in registers + x_norm->cacheBefore(); + + std::vector reduction_tvs = + scheduler_utils::getReductionTvs(fusion.get()); + + TensorView* reference_tv = x_norm; + + // boxDim array must be non-zero and less than or equal to 256 + constexpr int64_t width = 32; + constexpr int64_t vectorize = 4; + int64_t elem_per_compute_thread = tensor_inner_dim / width / vectorize; + constexpr int64_t examples_per_cta = 4; + + // Since multi-dim CpAsyncBulk has a size limit of 256 per dimension, + // we require multiple TMA operations to load the entire example in shared + // memory for pointwise kernel. + // + // Define TMA Box + // logical domain: [I1, I2] + x_cache_smem->split(0, examples_per_cta); + // split: [I0 / 4, 4, I2] + x_cache_smem->split(-1, 256); + // split: [I0/4, 4, I2/256, 256] + + // Schedule reference_tv + // logical domain: [I1, I2] + // split: [I1, I2/V (width / tdx), V] + reference_tv->split(-1, vectorize); + // split: [I1, EPCT, I2/V/EPCT (tdx), V] + reference_tv->split(-2, elem_per_compute_thread, /*inner_split=*/false); + // split: [I1, EPCT, I2/V/EPCT (tdx), U, V] + reference_tv->split(-2, 1); + // reorder: [I1, I2/V/EPCT (tdx), EPCT, U, V] + reference_tv->reorder({{-4, -3}, {-3, -4}}); + // reorder: [I1/EPC, EPC, I2/V/EPCT (tdx), EPCT, U, V] + reference_tv->split(0, examples_per_cta); + + TransformPropagator propagator(reference_tv); + std::vector all_tvs_except_cache = + ir_utils::allTvsExcept(fusion.get(), {x_cache_smem}); + SetSelector selector( + {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) + .traverse(&propagator); + + std::vector rfactor_tvs; + rfactor_tvs.reserve(reduction_tvs.size()); + std::transform( + reduction_tvs.begin(), + reduction_tvs.end(), + std::back_inserter(rfactor_tvs), + [](TensorView* tv) { return tv->rFactor({-3, -2, -1}); }); + + // Define Parallelization Schema + reference_tv->axis(0)->parallelize(ParallelType::BIDx); + reference_tv->axis(2)->parallelize(ParallelType::TIDx); + reference_tv->axis(-2)->parallelize(ParallelType::Unroll); + scheduler_utils::parallelizeAllLike(reference_tv); + + // Vectorize Cache + reference_tv->axis(-1)->parallelize(ParallelType::Vectorize); + + // InlineMost automatically handles vectorize and tma dimensions + inlineMost(); + + // Handle TMA Tensor + // Apply circular buffer after computeAt + x_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + if (examples_per_cta > 1) { + x_cache_smem->circularBuffer(number_of_stages); + } + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + at::Tensor at_tv0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor at_tv1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + + // Compile with FusionExecutor directly to avoid scheduling + FusionExecutor fe; + fe.compileFusion(fusion.get(), {at_tv0}); + std::vector cg_outputs = fe.runFusion({at_tv0}); + + std::tuple at_var_mean = + at::var_mean(at_tv0, {-1}, correction, keepdim); + at::Tensor at_var = std::get<0>(at_var_mean); + at::Tensor at_mean = std::get<1>(at_var_mean); + at::Tensor at_output = (at_tv0 - at_mean) / sqrt(at_var); + + testValidate( + fusion.get(), cg_outputs, {at_tv0}, {at_output}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Matmul) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // Algorithm + TensorView* tv0 = makeContigTensor(2); // (M, K) + TensorView* tv1 = makeContigTensor(2); // (K, N) + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion->addOutput(tv5); + + // CpAsyncBulk Store + TensorView* tv6 = tv5->cacheBefore(LoadStoreOpType::CpAsyncBulkTensorTile); + tv6->setMemoryType(MemoryType::Shared); + + // For register circular buffering + TensorView* tv0_cache_local = tv0->cacheAfter(); + TensorView* tv1_cache_local = tv1->cacheAfter(); + + // For shared memory circular buffering + TensorView* tv0_cache_smem = + tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + TensorView* tv1_cache_smem = + tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv0_cache_smem->setMemoryType(MemoryType::Shared); + tv1_cache_smem->setMemoryType(MemoryType::Shared); + + constexpr int64_t BSX = 32; + constexpr int64_t TSX = 8; + + // Step 0: [M, K, N] + // Step 1: [M, K, N/BSX, BSX] + tv6->split(-1, BSX); + + // Step 2: [M, K, N/BSX, BSX/TSX, TSX] + tv6->split(-1, TSX); + + // Step 3: [M, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(1, BSX); + + // Step 4: [M/BSX, BSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(0, BSX); + + // Step 5:[M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(1, TSX); + + // Step 6: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + tv6->reorder( + {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); + + // Step 7a: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX (reduce)] + // Step 7b: [M/BSX, N/BSX, K/BSX (reduce), BSX/TSX, BSX/TSX, TSX, TSX] + TensorView* tv6_rf = tv6->rFactor({-1}); + + TransformPropagatorWithCheck propagator(tv6_rf); + MaxLogicalDomainInfoSpanningTree(tv6_rf).traverse(&propagator); + + // IterDomain: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + // Parallelization: BDX, BDY, K/BSX ||, BSX/TSX, BSX/TSX, TDY, TSX, TDX] + // 4 non-parallelized for-loops + tv0->computeAt(tv6, 3); + tv1->computeAt(tv6, 3); + + tv6_rf->computeAt(tv6, -1); + tv0_cache_local->computeAt(tv6_rf, -1); + tv1_cache_local->computeAt(tv6_rf, -1); + + // Parallelize + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-3)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv5); + + // (BSX/TSX * TSX * BSX) = 1024 floats = 4096 bytes * (number of buffers) + // Apply circular buffering to smem and local cache tensors + tv0_cache_smem->axis(-3)->parallelize(ParallelType::Bulk); + tv0_cache_smem->axis(-2)->parallelize(ParallelType::Bulk); + tv0_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + + tv1_cache_smem->axis(-3)->parallelize(ParallelType::Bulk); + tv1_cache_smem->axis(-2)->parallelize(ParallelType::Bulk); + tv1_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + + tv0_cache_local->circularBuffer(number_of_stages); + tv1_cache_local->circularBuffer(number_of_stages); + + tv0_cache_smem->circularBuffer(number_of_stages); + tv1_cache_smem->circularBuffer(number_of_stages); + + // Apply ParallelType::Bulk to global output tensor. + tv5->axis(-4)->parallelize(ParallelType::Bulk); + tv5->axis(-3)->parallelize(ParallelType::Bulk); + tv5->axis(-2)->parallelize(ParallelType::Bulk); + tv5->axis(-1)->parallelize(ParallelType::Bulk); + + constexpr int64_t K = 1024; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, K}, options); + at::Tensor t1 = at::randn({K, tensor_inner_dim}, options); + at::Tensor aten_output = + (t0.unsqueeze(/*dim=*/-1) * t1.unsqueeze(/*dim=*/0)).sum(/*dim=*/1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0, t1}); + + std::vector cg_outputs = fe.runFusion({t0, t1}); + compare( + tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), aten_output); + testValidate( + fusion.get(), cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); +} + +// Test circular buffer from 2 to 5 stages +INSTANTIATE_TEST_SUITE_P( + Hopper, + TmaCircularBufferingTest, + testing::Combine( + ::testing::Range(2, 5), + testing::Values(128, 500, 1024), + testing::Values(128, 1024))); + } // namespace nvfuser