diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 6c81629af0f..53c7879e8b0 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include @@ -23,11 +24,12 @@ namespace { // buffer tensor is on smem, in which case it would otherwise require // an additional predicate to guard buffer overruns. When it's on // gmem, that isn't the case, so it does not need to create an -// epilogue loop. +// epilogue loop. For TMA cpAsyncBulk, there is always an epilogue loop. bool requireEpilogue(const std::vector& exprs) { return std::any_of(exprs.begin(), exprs.end(), [](const Expr* expr) { - return expr->input(0)->as()->getMemoryType() == - MemoryType::Shared; + return (expr->input(0)->as()->getMemoryType() == + MemoryType::Shared) || + ir_utils::isCpAsyncBulk(expr); }); } @@ -215,6 +217,584 @@ class CircularBufferLoopCloner : public kir::IrVisitor { const std::unordered_set& exclude_; }; +// Description: +// Replicates circular buffer loops for Prologue, Main, and +// Epilogue. Prologue only copies the load expressions of circular +// buffered tensors, whereas Epilogue does any expression other than +// the loads. Main copies everything. The pre-prologue and post-epilogue loops +// are created separately by createCpAsyncBulkFixtures. +// +// Loop Structure Overview: +// Pre-prologue loop: +// - Allocate shared memory for mbarriers and mbarrier tokens +// - Initialize mbarrier for all stages +// +// Prologue loop: +// - if selected_thread: +// - Issue cp async bulks for all but last stage +// +// Main loop: +// - if selected_thread: +// - Issue next cp async bulk for available stage +// - All threads wait until tma operation arrives +// - Copy body without +// - shared memory allocations +// - mbarrier_init exprs +// - mbarrier_inval exprs +// +// Epilogue loop: +// - All threads wait until tma operation arrives +// - Copy body without +// - shared memory allocations +// - issuing cp async bulk operations +// - mbarrier_init exprs +// - mbarrier_inval exprs +// +// Post-epilogue loop: +// - if selected_thread: +// - Invalidated mbarrier for all stages +// +// Detailed Pseudo-Code: +// Pre-Prologue loop: +// +// - number_of_arrival_threads is the number of threads to call +// mbarrier::arrive or mbarrier::arriveExpectTx and to wait at +// mbarrier:wait. +// +// __shared__ __mbarrier_t barriers[num_stages]; +// __shared__ __mbarrier_token_t tokens[num_stages]; +// if (warp_id == 0 && electSync()()) { +// for (int64_t loop_index : irange(stages)) { +// int64_t number_of_arrive_threads = blockDim.x * blockDim.y * blockDim.z; +// mbarrier_init(mbarrier[loop_index], number_of_arrival_threads); +// } +// } +// +// Prologue loop: +// for (int64_t loop_index : irange(stages-1)) { +// if (warp_id == 0 && electSync()()) { +// tokens[loop_index] = +// mbarrier::arriveExpectTx(mbarrier[loop_index], expected_bytes); +// for (...) { +// cpAsyncBulk(mbarriers[loop_index], ...); +// } +// } else { +// tokens[loop_index] = +// mbarrier::arrive(mbarrier[loop_index]); +// } +// } +// +// Main loop: +// for (int64_t loop_index : irange(N-(stages-1))) { +// current_stage = loop_index % stage_depth +// load_stage = (loop_index + (stage_depth - 1)) % stage_depth) +// if (warp_id == 0 && electSync()()) { +// token[load_stage] = +// mbarrier::arriveExpectTx(mbarrier[load_stage], expected_bytes); +// for (...) 
{ +// cpAsyncBulk(mbarrier[load_stage], ...); +// } +// } else { +// token[load_stage] = +// mbarrier::arrive(mbarrier[load_stage]); +// } +// mbarrier::wait(token[current_stage]); +// +// Clone remaining operations +// } +// +// Epilogue loop: +// for (int64_t loop_index : irange(N-(stages-1), N)) { +// current_stage = loop_index % stage_depth +// mbarrier::wait(token[current_stage]); +// +// Clone remaining operations +// } +// +// Post-Epilogue loop: +// if (warp_id == 0 && electSync()()) { +// for (int64_t loop_index : irange(stages)) { +// mbarrier_inval(mbarrier[loop_index]); +// } +// } +// +class CloneTmaCircularBufferLoopAndInsertSync + : public CircularBufferLoopCloner { + public: + static ForLoop* clone( + ForLoop* circular_buffer_loop, + const std::vector& circular_buffer_load_exprs, + CircularBufferLoopStage loop_type, + const std::unordered_set& exclude = {}) { + CloneTmaCircularBufferLoopAndInsertSync cloner( + circular_buffer_loop, circular_buffer_load_exprs, loop_type, exclude); + cloner.duplicate(); + return cloner.cloned_top_level_loop_; + } + + private: + CloneTmaCircularBufferLoopAndInsertSync( + ForLoop* circular_buffer_loop, + const std::vector& circular_buffer_load_exprs, + CircularBufferLoopStage loop_type, + const std::unordered_set& exclude) + : CircularBufferLoopCloner( + circular_buffer_loop, + circular_buffer_load_exprs, + loop_type, + exclude) {} + + // For TmaCircularBufferLoop, we have an mbarrier for each Tensorview and + // each circular buffer stage, but not for each individual TMA load + // operation. If there are serial IterDomains to the right of the computeAt + // position, nvfuser will generate a for-loop to launch multiple TMA load + // operations. This for-loop is passed to processForLoop as the cloned_loop + // argument. + // + // When we encounter a CpAsyncBulk load expression, we create a mbarrier_wait + // for the main and epilogue loops and a arriveExpectTx for prologue and main + // loops. handleMainLoop and handleEpilogLoop create mbarrier_wait expression. + // handleMainLoop and handlePrologLoop create mbarrier::arriveExpectTx + // expression. The expected_tx for arriveExpectTx is the cumulative + // transaction size for all TMA load operations for the TensorView. Next, we + // generate the nested for-loops for the serial IterDomains, but do not add + // them to the cloned circular buffer loop immediately. Once the cloned + // circular buffer loop is the only loop in the stack, add the arriveExpectTx + // and arrive expressions, then the nested for-loop structure calling the TMA + // load operations, and finally the mbarrier_wait. + void processForLoop(ForLoop* cloned_loop) final { + // Skip if there is not an active for-loop structure + if (for_loop_stack_.empty()) { + return; + } + + if (!cloned_loop->body().empty()) { + // mbarrier_arrive_tx_ is active when we encounter a cpAsyncBulk load + // operation on a circular buffer TensorView in IrVisitor. A single + // mbarrier_arrive_tx is active for each TensorView. + if (mbarrier_arrive_tx_ == nullptr || for_loop_stack_.size() > 1) { + // Add cloned for_loop when mbarrier_arrive_tx_ is not active or + // we are within a nested for-loop structure + for_loop_stack_.back()->body().push_back(cloned_loop); + } else { + // mbarrier::arriveExpectTx and TMA load operations occur in prologue + // and main loops. + // + // cloned_loop is nested for-loop containing cpAsyncBulk expressions. + // addTmaLoadBlock replaces the cloned_loop with: + // + // if(elect) { + // arriveExpectTx; + // for (...) 
{ + // cpAsyncBulk; + // } + // } else { + // arrive; + // } + NVF_ERROR(for_loop_stack_.front() == cloned_top_level_loop_); + addTmaLoadBlock(cloned_loop); + } + } + + // mbarrier::wait occurs in Main and Epilogue loops. + if (mbarrier_wait_ != nullptr && for_loop_stack_.size() == 1) { + NVF_ERROR(for_loop_stack_.back() == cloned_top_level_loop_); + cloned_top_level_loop_->body().push_back(mbarrier_wait_); + mbarrier_wait_ = nullptr; + } + } + + void processExpr(Expr* expr) final { + // A mbarrier token exists if the TensorView output for cpAsyncBulk load has + // circular buffering depth > 1. ldstMBarrierTokenMap maps mbarrier_init, + // mbarrier_inval, and cpAsynBulk to the same mbarrier token. + bool mbarrier_token_exists = + GpuLower::current()->tmaCircularBufferInfo().existsMBarrierToken(expr); + + // Handle Short-Circuit conditions + switch (loop_type_) { + case CircularBufferLoopStage::Prolog: { + TensorView* out_tv = ir_utils::getTvOutput(expr); + bool is_circular_buffer_load_expr = std::any_of( + circular_buffer_load_exprs_.begin(), + circular_buffer_load_exprs_.end(), + [out_tv](Expr* load_expr) { + TensorView* circular_buffer_tv = ir_utils::getTvOutput(load_expr); + NVF_ERROR(circular_buffer_tv != nullptr); + return out_tv == circular_buffer_tv; + }); + // Short-circuit: skip expression if it is not circular buffer load + // expression. + if (!is_circular_buffer_load_expr) { + return; + } + + // Short-circuit: There can be circular buffered loads without + // cpAsyncBulk load expressions. + if (!mbarrier_token_exists) { + for_loop_stack_.back()->body().push_back(expr); + return; + } + break; + } + case CircularBufferLoopStage::Main: + case CircularBufferLoopStage::Epilog: { + // Short-circuit: Add expression if not circular-buffered load store + // operation. + if (!expr->isA() || !mbarrier_token_exists) { + for_loop_stack_.back()->body().push_back(expr); + return; + } + break; + } + case CircularBufferLoopStage::NotApplicable: { + NVF_ERROR(false, "Unsupported loop mode, got: ", loop_type_); + } + } + + // Handle cpAsyncBulk expression with circular buffered TensorView output. + switch (loop_type_) { + case CircularBufferLoopStage::Prolog: { + return handlePrologueLoop(expr); + } + case CircularBufferLoopStage::Main: { + return handleMainLoop(expr); + } + case CircularBufferLoopStage::Epilog: { + return handleEpilogLoop(expr); + } + case CircularBufferLoopStage::NotApplicable: { + NVF_ERROR(false, "Unsupported loop mode, got: ", loop_type_); + } + } + } + + // Replace cpAsyncBulk type LoadStoreOp with: + // if (warp_id == 0 && electSync()()) { + // tokens[loop_index] = + // mbarrier::arriveExpectTx(mbarrier[loop_index], expected_bytes); + // for (...) { + // cpAsyncBulk(mbarriers[loop_index], ...); + // } + // } else { + // tokens[loop_index] = mbarrier::arrive(mbarrier[loop_index]); + // } + // } + void handlePrologueLoop(Expr* expr) { + NVF_ERROR(expr != nullptr); + + // Skip if not LoadStoreOp expression + if (!expr->isA()) { + return; + } + + LoadStoreOp* ldst = expr->as(); + + // There should be a single mbarrier_arrive_tx_ for all ldst in current + // stage. 
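+  // In the prologue, the stage index is simply the top-level loop index, since
+  // only the first (stages - 1) iterations are issued here.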
+ NVF_ERROR(mbarrier_arrive_tx_ == nullptr); + mbarrier_arrive_tx_ = createMbarrierArriveExpectTx( + ldst, cloned_top_level_loop_->indexOrStartIfTrivial()); + + // Clone LoadStoreOp and map it to mbarrier alloc + Expr* new_ldst = + IrBuilder::create( + ldst->opType(), ldst->out(), ldst->in(), ldst->cacheOp()) + ->withPredicate(ldst->predicate()); + + // Register mbarrier object to be used with new LoadStoreOp + // from prolog loop + NVF_ERROR(mbarrier_arrive_tx_->mbarrier()->isA()); + GpuLower::current()->tmaCircularBufferInfo().recordTensorIndex( + new_ldst, mbarrier_arrive_tx_->mbarrier()->as()); + + // If last cloned scope is the cloned_top_level_loop body, then add + // mbarrier::arriveExpectTx and new loadStoreOp. + int64_t active_for_loops = std::count_if( + for_loop_stack_.begin(), for_loop_stack_.end(), [](ForLoop* fl) { + return fl->iter_domain()->getParallelType() == ParallelType::Serial; + }); + if (active_for_loops == 1) { + return addTmaLoadBlock(new_ldst); + } + + // Otherwise, we are in a nested for-loop and should wait until we + // return to top-level for loop. + for_loop_stack_.back()->body().push_back(new_ldst); + } + + // Handle cpAsyncBulk type LoadStoreOp that is registered with token + // + // compute_stage = loop_index % stage_depth + // load_stage = (loop_index + (stage_depth - 1)) % stage_depth) + // + // Replace LoadStoreOp with: + // if (warp_id == 0 && electSync()()) { + // token[load_stage] = + // mbarrier::arriveExpectTx(mbarrier[load_stage], expected_bytes); + // for (...) { + // cpAsyncBulk(mbarrier[load_stage], ...); + // } + // } else { + // token[load_stage] = mbarrier::arrive(mbarrier[load_stage]); + // } + // mbarrier::wait(token[current_stage]); + // + // Where mbarrier and token are shared memory arrays bound to the + // LoadStoreOp + void handleMainLoop(Expr* expr) { + NVF_ERROR(expr != nullptr && expr->isA()); + + int64_t stage_depth = + GpuLower::current()->circularBufferInfo().getStageDepthFor( + circular_buffer_loop_->iter_domain()); + + if (current_compute_stage_ == nullptr) { + current_compute_stage_ = IrBuilder::modExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::create(stage_depth, PrimDataType::Index)); + kir::Allocate* current_compute_stage_alloc = + IrBuilder::create( + current_compute_stage_, + MemoryType::Local, + IrBuilder::create(1L, PrimDataType::Index), + /*zero_init=*/false); + cloned_top_level_loop_->body().push_back(current_compute_stage_alloc); + cloned_top_level_loop_->body().push_back( + current_compute_stage_->definition()); + } + + if (current_load_stage_ == nullptr) { + current_load_stage_ = IrBuilder::modExpr( + IrBuilder::addExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::subExpr( + IrBuilder::create(stage_depth, PrimDataType::Index), + IrBuilder::create(1L, PrimDataType::Index))), + IrBuilder::create(stage_depth, PrimDataType::Index)); + kir::Allocate* current_load_stage_alloc = + IrBuilder::create( + current_load_stage_, + MemoryType::Local, + IrBuilder::create(1L, PrimDataType::Index), + /*zero_init=*/false); + cloned_top_level_loop_->body().push_back(current_load_stage_alloc); + cloned_top_level_loop_->body().push_back( + current_load_stage_->definition()); + } + + LoadStoreOp* ldst = expr->as(); + + // There is a single mbarrier_arrive_tx_ for each cpAsyncBulk load + // expression. A mbarrier_arrive_tx_ for another cpAsyncBulk load expression + // should not be active. 
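+  // The arriveExpectTx targets current_load_stage_, i.e. the stage being
+  // prefetched (stage_depth - 1) iterations ahead of current_compute_stage_.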
+ NVF_ERROR(mbarrier_arrive_tx_ == nullptr); + mbarrier_arrive_tx_ = + createMbarrierArriveExpectTx(ldst, current_load_stage_); + + // Register mbarrier object to be used with LoadStoreOp + // from main loop + NVF_ERROR(mbarrier_arrive_tx_->mbarrier()->isA()); + GpuLower::current()->tmaCircularBufferInfo().recordTensorIndex( + ldst, mbarrier_arrive_tx_->mbarrier()->as()); + + // Construct mBarrier::wait for current stage + NVF_ERROR( + mbarrier_wait_ == nullptr, + "Expected mbarrier_wait to inactive for current TMA operation"); + mbarrier_wait_ = createMbarrierWait(ldst, current_compute_stage_); + + // If last cloned scope is the cloned_top_level_loop body, then add + // mbarrier::arriveExpectTx, new loadStoreOp, and mbarrier_wait + int64_t active_for_loops = std::count_if( + for_loop_stack_.begin(), for_loop_stack_.end(), [](ForLoop* fl) { + return fl->iter_domain()->getParallelType() == ParallelType::Serial; + }); + if (active_for_loops == 1) { + addTmaLoadBlock(ldst); + NVF_ERROR(mbarrier_wait_ != nullptr); + for_loop_stack_.back()->body().push_back(mbarrier_wait_); + mbarrier_wait_ = nullptr; + return; + } + + // Otherwise, we are in a nested for-loop and should wait until we + // return to top-level for loop. + for_loop_stack_.back()->body().push_back(ldst); + } + + void handleEpilogLoop(Expr* expr) { + NVF_ERROR(expr != nullptr && expr->isA()); + + // Construct mBarrier::wait for epilogue + LoadStoreOp* ldst = expr->as(); + int64_t stage_depth = + GpuLower::current()->circularBufferInfo().getStageDepthFor( + circular_buffer_loop_->iter_domain()); + Val* epilogue_compute_stage = IrBuilder::modExpr( + cloned_top_level_loop_->indexOrStartIfTrivial(), + IrBuilder::create(stage_depth, PrimDataType::Index)); + + NVF_ERROR( + mbarrier_wait_ == nullptr, + "Expected mbarrier_wait to inactive for current TMA operation"); + kir::MBarrierWait* mbarrier_wait = + createMbarrierWait(ldst, epilogue_compute_stage); + for_loop_stack_.back()->body().push_back(mbarrier_wait); + } + + // This function selects a single thread to launch tma load and mbarrier + // arrive_expected_tx operations. The remaining threads will simply arrive + // at the mbarrier. + // + // Pseudo-code example: + // if (warp_id == 0 && electSync()()) { + // tokens[next_stage] = + // mbarrier::arriveExpectTx(mbarrier[next_stage], + // expected_bytes); + // for (...) { + // cpAsyncBulk(mbarrier[next_stage], ...); + // } + // } else { + // tokens[next_stage] = mbarrier::arrive(mbarrier[next_stage]); + // } + // + // The expr input argument can be a single cpAsyncBulk expression or a nested + // for-loop structure of cpAsyncBulk expressions if there are serial + // iterDomains to the right of the computeAt position. + void addTmaLoadBlock(Expr* expr) { + NVF_ERROR(mbarrier_arrive_tx_ != nullptr); + NVF_ERROR(expr != nullptr); + + // Create the if-then-else with electSync() predicate for the arrive expect + // transaction. + kir::IfThenElse* if_expr = IrBuilder::create( + IrBuilder::create(PredicateType::ElectSync)); + + // A single thread issues arriveExpectTx with expected transactions and + // launches the TMA load. + if_expr->thenBody().push_back(mbarrier_arrive_tx_); + if_expr->thenBody().push_back(expr); + + // The other threads issue arriveExpectTx without any expected transactions. 
+ kir::MBarrierArrive* thread_arrive = IrBuilder::create( + mbarrier_arrive_tx_->state(), mbarrier_arrive_tx_->mbarrier()); + if_expr->elseBody().push_back(thread_arrive); + for_loop_stack_.back()->body().push_back(if_expr); + + mbarrier_arrive_tx_ = nullptr; + } + + // Get size of tma load in bytes. It is used for expected transaction count in + // kir::MBarrierArriveExpectTx. + Val* getSizeOfTmaLoad(LoadStoreOp* ldst) { + NVF_ERROR(ldst != nullptr); + + TensorView* consumer_tv = ldst->out()->as(); + NVF_ERROR( + GpuLower::current()->consumerToTMAInfo().count(consumer_tv), + "Unable to find TMA info for consumer_tv: ", + consumer_tv->toString()); + + // Get expected bytes for given TMA load operation. + const TMAInfo& tma_info = + GpuLower::current()->consumerToTMAInfo().at(consumer_tv); + Val* expected_bytes = tma_info.tileSizeBytes(); + + // The expected_bytes for mbarrier::arriveExpectTX must account for all TMA + // load operations launched for each circular buffer stage. We take the + // product of all coordinate TMA iterDomains to the right of the circular + // buffer axis. + const std::vector& loop_domain = consumer_tv->getLoopDomain(); + const IdModel& id_model = GpuLower::current()->idModel(); + for (size_t idx = consumer_tv->getComputeAtPosition(); + idx < loop_domain.size(); + ++idx) { + IterDomain* id = + indexing_utils::getLoopPromotion(loop_domain.at(idx), id_model); + if (!isParallelTypeThread(id->getParallelType()) && + id->getParallelType() != ParallelType::Bulk) { + expected_bytes = + SimplifyingIrBuilder::mulExpr(expected_bytes, id->extent()); + } + } + expected_bytes = + SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, expected_bytes); + return expected_bytes; + } + + // This function creates kir::MBarrierArriveExpectTx for given LoadStoreOp and + // circular buffer stage. + // + // Example: + // tokens[stage] = + // mbarrier::arriveExpectTX(toSmem((&barriers[stage])), + // getSizeOfTmaLoad(ldst)); + kir::MBarrierArriveExpectTx* createMbarrierArriveExpectTx( + LoadStoreOp* ldst, + Val* loop_index) { + NVF_ERROR(ldst != nullptr); + NVF_ERROR(loop_index != nullptr); + + loop_index = GpuLower::current()->commonScalarMap().hoistScalar( + loop_index, for_loop_stack_); + + // Get mbarrier for this circular buffer stage. + TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + kir::TensorIndex* stage_mbarrier = + IrBuilder::create(all_mbarriers, loop_index); + + // Get mbarrier_token for this circular buffer stage. + TensorView* all_mbarrier_tokens = + GpuLower::current()->tmaCircularBufferInfo().getMBarrierToken(ldst); + kir::TensorIndex* stage_token = + IrBuilder::create(all_mbarrier_tokens, loop_index); + + Val* tx_count = GpuLower::current()->commonScalarMap().hoistScalar( + getSizeOfTmaLoad(ldst), for_loop_stack_); + kir::MBarrierArriveExpectTx* mbarrier_arrive_tx = + IrBuilder::create( + stage_token, stage_mbarrier, tx_count); + + return mbarrier_arrive_tx; + } + + // This function creates kir::MBarrierWait for given LoadStoreOp and circular + // buffer stage. + kir::MBarrierWait* createMbarrierWait(LoadStoreOp* ldst, Val* loop_index) { + NVF_ERROR(ldst != nullptr); + NVF_ERROR(loop_index != nullptr); + + // Get mbarrier_token for this circular buffer stage. + TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + kir::TensorIndex* stage_mbarrier = + IrBuilder::create(all_mbarriers, loop_index); + + // Get mbarrier_token for this circular buffer stage. 
+ TensorView* all_mbarrier_tokens = + GpuLower::current()->tmaCircularBufferInfo().getMBarrierToken(ldst); + kir::TensorIndex* stage_token = + IrBuilder::create(all_mbarrier_tokens, loop_index); + + kir::MBarrierWait* mbarrier_wait = + IrBuilder::create(stage_mbarrier, stage_token); + return mbarrier_wait; + } + + private: + // Mbarrier_Wait to add to cloned_top_level_loop + kir::MBarrierWait* mbarrier_wait_ = nullptr; + + // Mbarrier_ArriveExpectTx to add to cloned_top_level_loop + kir::MBarrierArriveExpectTx* mbarrier_arrive_tx_ = nullptr; + + // current_stage_index = (loop_index % stages) + Val* current_compute_stage_ = nullptr; + + // next_stage_index = (loop_index + (stages-1)) % stages + Val* current_load_stage_ = nullptr; +}; + using InsertionInfo = std::unordered_map>; class IsCircularBufferLoadLoop : public kir::IrVisitor { @@ -387,11 +967,51 @@ class CircularBufferInserter : private kir::ExprMutator { return; } - insert(loop, it->second); + auto hasCpAsyncBulk = std::any_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); + + if (hasCpAsyncBulk) { + insertTma(loop, it->second); + } else { + insert(loop, it->second); + } processed_loop_ = loop; insertion_info_.erase(loop); } + void insertTma( + ForLoop* circular_buffer_loop, + const std::vector& loads) { + // Prologue loop: + // - launch only + // - arrive_expect_tx and tma load operations + ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::Prolog); + registerInsertBefore(circular_buffer_loop, prologue_loop); + + // Main loop: + // - Launch and wait + // - arrive_expect_tx, tma load operations, and mbarrier_wait) + ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::Main); + registerReplace(circular_buffer_loop, main_loop); + + // We can use exclude argument in CloneTmaCircularBufferLoopAndInsertSync + // clone to avoid duplicating allocations if main loop is trivial. 
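+    // When the main loop is trivial (a single iteration), its allocations are
+    // emitted directly into the enclosing scope, so they are passed as the
+    // exclude set to keep the epilogue clone from allocating them again.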
+ std::unordered_set expressions_allocated_in_main_loop; + getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop); + + // Epilogue loop: + // - wait only + // - mbarrier_wait + ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Epilog, + expressions_allocated_in_main_loop); + registerInsertAfter(circular_buffer_loop, epilogue_loop); + } + void insert(ForLoop* circular_buffer_loop, const std::vector& loads) { ForLoop* prologue_loop = CircularBufferLoopCloner::clone( circular_buffer_loop, loads, CircularBufferLoopStage::Prolog); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 26c94568be6..9c63d213027 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1441,6 +1441,17 @@ void IndexLowering::handle(const kir::MBarrierInvalidate* minval) { GpuLower::current()->propagateExprInfo(minval, minval_indexed); } +void IndexLowering::handle(const kir::MBarrierArrive* arrive_transaction) { + NVF_ERROR( + arrive_transaction->mbarrier()->isA(), + "Expected kir::TensorIndex in MBarrierArriveExpectTx"); + + Val* smem_address_ptr = lower_utils::u32IndexScalarSmemTv( + arrive_transaction->mbarrier()->as()); + pushBack(IrBuilder::create( + arrive_transaction->state(), smem_address_ptr)); +} + void IndexLowering::handle( const kir::MBarrierArriveExpectTx* arrive_transaction) { NVF_ERROR( diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 41d71726beb..d14809e5ede 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -74,6 +74,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::GridSync*) final; void handle(const kir::MBarrierInit*) final; void handle(const kir::MBarrierInvalidate*) final; + void handle(const kir::MBarrierArrive*) final; void handle(const kir::MBarrierArriveExpectTx*) final; void handle(const kir::MBarrierWait*) final; void handle(const kir::AsyncWait*) final; diff --git a/csrc/device_lower/pass/unroll.cpp b/csrc/device_lower/pass/unroll.cpp index f12847a9cf8..3ffed3a8dea 100644 --- a/csrc/device_lower/pass/unroll.cpp +++ b/csrc/device_lower/pass/unroll.cpp @@ -42,6 +42,27 @@ void UnrollPass::registerReplace(Expr* reference, Expr* new_expr) { } void UnrollPass::dispatch(Expr* expr) { + // short-circuit: skip adding predicate if tma load with circular buffering or + // stand-alone arrive_expect_tx. + bool is_arrive_expect_tx = expr->isA(); + bool is_circular_buffer_tma_load = ir_utils::isCpAsyncBulkLoad(expr) && + expr->output(0)->as()->isCircularBuffered(); + if (is_arrive_expect_tx || is_circular_buffer_tma_load) { + return; + } + + // short-circuit: mbarrier_init or mbarrier_inval with elect sync predicate. + // predicate is specified for tma load with circular buffering. 
+ bool is_mbarrier_init = expr->isA(); + bool is_mbarrier_inval = expr->isA(); + if ((is_mbarrier_init || is_mbarrier_inval) && expr->predicate() != nullptr) { + kir::IfThenElse* inline_ite = + IrBuilder::create(expr->predicate()); + kir::ExprMutator::registerReplace(expr, inline_ite); + inline_ite->thenBody().push_back(expr); + return; + } + if (ir_utils::isTvOp(expr)) { DEBUG_PRINT_SCOPE_NAME("UnrollPass::dispatch", expr); // If tv op, predicate it diff --git a/csrc/device_lower/pass/warp_reduce.cpp b/csrc/device_lower/pass/warp_reduce.cpp index a747af888e7..02e40e19e96 100644 --- a/csrc/device_lower/pass/warp_reduce.cpp +++ b/csrc/device_lower/pass/warp_reduce.cpp @@ -193,7 +193,22 @@ class FuseBroadcastWithWarpReduce : private kir::IrVisitor { kir::IrVisitor::dispatch(expr); } - bool openLoopNestLevel(IterDomain* id) { + bool openLoopNestLevel(ForLoop* fl) { + // circular buffering duplicates for-loops. Depending on the number of + // iterations in for-loop and size of circular buffering pipeline, either + // the main loop or epilogue loops can be trivial. In this case, we do not + // open another loop nest level. Allocations are added directly to the + // previous level. + bool is_main_loop = + fl->circularBufferLoopStage() == CircularBufferLoopStage::Main; + bool is_epilogue_loop = + fl->circularBufferLoopStage() == CircularBufferLoopStage::Epilog; + + if ((is_main_loop || is_epilogue_loop) && fl->isTrivial()) { + return false; + } + + IterDomain* id = fl->iter_domain(); if (id->isThread() || id->getParallelType() == ParallelType::Unswitch) { return false; } @@ -206,7 +221,7 @@ class FuseBroadcastWithWarpReduce : private kir::IrVisitor { void handle(ForLoop* for_loop) final { // Keep track of visible reduction outputs - bool open_nest_level = openLoopNestLevel(for_loop->iter_domain()); + bool open_nest_level = openLoopNestLevel(for_loop); if (open_nest_level) { running_tv_to_allocate_map_.emplace_back( std::make_unique>()); diff --git a/doc/dev/tma.md b/doc/dev/tma.md index b4fc12b31c7..1a4286da288 100644 --- a/doc/dev/tma.md +++ b/doc/dev/tma.md @@ -353,6 +353,7 @@ the TMA domain can be completely inferred from the schedule. > When using circular buffering with TMA, a single thread is select to launch the TMA load and mbarrier operations. > In this case, we cannot apply any block parallelization to the consumer TensorView, which will create a thread predicate. > A compile-time error will occur if you apply circular buffering and block parallelization together. +> See `TEST_F(NVFuserTest, ElectSyncCompatibility)` for an example. 
#### Data swizzle diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 0861b227cac..772b16f4ca9 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -6,9 +6,12 @@ */ // clang-format on +#include #include +#include #include #include +#include namespace nvfuser { @@ -802,4 +805,829 @@ INSTANTIATE_TEST_SUITE_P( CircularBufferingTest, ::testing::Range(2, 10)); +using TmaCircularBufferingParams = std::tuple; + +class TmaCircularBufferingTest + : public NVFuserFixtureParamTest { + protected: + int64_t number_of_stages = 1; + int64_t tensor_outer_dim = 1; + int64_t tensor_inner_dim = 1; + + void SetUp() override { + number_of_stages = std::get<0>(GetParam()); + tensor_outer_dim = std::get<1>(GetParam()); + tensor_inner_dim = std::get<2>(GetParam()); + + // NOTE: Multiple of 16 required for inner dimension + NVF_ERROR(tensor_inner_dim % 16 == 0); + NVFuserTest::SetUp(); + } + + template + void compare(int64_t tensor_dim, at::Tensor result, at::Tensor reference) { + at::Tensor reference_cpu_data = reference.cpu(); + at::Tensor result_cpu_data = result.cpu(); + + auto reference_cpu = reference_cpu_data.accessor(); + auto result_cpu = result_cpu_data.accessor(); + + constexpr double tolerance = 1e-3; + for (int64_t pos = 0; pos < tensor_dim; ++pos) { + if (fabs((double)result_cpu[pos] - (double)reference_cpu[pos]) > + tolerance) { + std::cout << "[" << pos << "] - result: " << result_cpu[pos] + << " | reference: " << reference_cpu[pos] << std::endl; + } + } + } + + template + void compare( + int64_t tensor_outer_dim, + int64_t tensor_inner_dim, + at::Tensor result, + at::Tensor reference) { + at::Tensor reference_cpu_data = reference.cpu(); + at::Tensor result_cpu_data = result.cpu(); + + auto reference_cpu = reference_cpu_data.accessor(); + auto result_cpu = result_cpu_data.accessor(); + + constexpr double tolerance = 1e-3; + for (int64_t out_pos = 0; out_pos < tensor_outer_dim; ++out_pos) { + for (int64_t in_pos = 0; in_pos < tensor_inner_dim; ++in_pos) { + if (fabs( + (double)reference_cpu[out_pos][in_pos] - + (double)result_cpu[out_pos][in_pos]) > tolerance) { + std::cout << "[" << out_pos << ", " << in_pos + << "] - result: " << result_cpu[out_pos][in_pos] + << " | ref: " << reference_cpu[out_pos][in_pos] + << std::endl; + } + } + } + } +}; + +TEST_F(NVFuserTest, ElectSyncCompatibility) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* input = makeContigTensor(3); + fusion->addInput(input); + TensorView* output = set(input); + fusion->addOutput(output); + + TensorView* smem_cache = + input->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + smem_cache->setMemoryType(MemoryType::Shared); + + // For TMA load, both the shared memory layout and the loop nest and + // parallelization of TMA are specified by the consumer: smem_cache + + // Step 1: define TMA domain + // Because we want to treat the entire tensor as 1D, we define the TMA + // domain as [I0*I1*I2] + smem_cache->merge(0); + smem_cache->merge(0); + // Note that the TMA domain only exist in people's mind, there is no need to + // set anything here. + + // Step 2: define box + smem_cache->split(0, 256); + // [I0*I1*I2/256, 256] + // partitioned IterDomain: I0*I1*I2 + // coordinate IterDomain: I0*I1*I2/256 + // box IterDomain: 256 + + // Step 3: define tile + // We use dense tile here, so tile == box. Nothing to do here. 
+ + // Step 4: schedule the shared memory tensor + // By default, the allocation domain is the logical domain, which is already + // in good shape for this case. + + constexpr int64_t number_of_stages = 2; + // Step 5: schedule the consumer tensor + smem_cache->split(0, 4); + // [I0*I1*I2/256/4, 4, 256] + smem_cache->split(0, number_of_stages); + // [I0*I1*I2/256/4/2, 2, 4, 256] + + // [BIDx, 2, TIDx, Bulk] + smem_cache->axis(0)->parallelize(ParallelType::BIDx); + smem_cache->axis(2)->parallelize(ParallelType::TIDx); + smem_cache->axis(3)->parallelize(ParallelType::Bulk); + + // Schedule the smem->gmem part + output->merge(0); + output->merge(0); + output->split(0, 256); + output->split(0, 4); + output->split(0, number_of_stages); + output->axis(0)->parallelize(ParallelType::BIDx); + output->axis(3)->parallelize(ParallelType::TIDx); + + inlineAllAt(output, /*pos=*/2); + smem_cache->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + std::vector shape(3, 300); + auto t0 = at::randn(shape, options); + + // IterDomain 2 for the TMA load is parallelized with TIDx, so we generate + // (threadIdx.x < 4) predicate. This thread predicate is incompatible with + // circular buffering because we generate an ElectSync predicate that uses + // a single thread. + FusionExecutor fe; + try { + fe.compileFusion(fusion.get(), {t0}); + } catch (const std::exception& e) { + const char* reference = + R"(This thread-parallelized TensorView T2_s_float[ iblockIdx.x15{( ceilDiv(( ceilDiv(( ceilDiv(( ( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_size ))[1] ) ) * ( (( (( getMetaData(T0) )).logical_size ))[2] ) ), 256) ), 4) ), 2) )}, iS16{2}, ithreadIdx.x14{4}, iB12{256} ] ca_pos( 2 ) is incorrectly contained within a If-Then-Else with the ElectSync predicate.)"; + const char* str_match_pointer = strstr(e.what(), reference); + ASSERT_TRUE(str_match_pointer != nullptr); + } +} + +TEST_P(TmaCircularBufferingTest, SingleDim) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set inlineAt before applying circular buffer + inlineAllAt(tv1, /*pos=*/1); + + // Parallelization + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + // Circular Buffer with TMA loads + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, SingleDimUnroll) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + 
FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t unroll_dim = 4; + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + // [M/bid, bid] -> [M/bid/unroll, unroll, bid] + reference->split(0, unroll_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set ComputeAt position + inlineAllAt(tv1, /*pos=*/1); + + // Apply Unroll + tv1->axis(1)->parallelize(ParallelType::Unroll); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + // Circular Buffer with TMA loads + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + int64_t axis_extent = + ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim); + if (axis_extent < number_of_stages) { + ASSERT_ANY_THROW(fe.runFusion({t0})); + return; + } + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, SingleDimUnswitch) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(1); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr size_t unroll_dim = 4; + constexpr size_t bulk_inner_dim = 32; + + // [M] -> [M/bid, bid] + reference->split(-1, bulk_inner_dim); + // [M/bid, bid] -> [M/bid/unroll, unroll, bid] + reference->split(0, unroll_dim); + + // Propagate Transformations + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set ComputeAt position + inlineAllAt(tv1, /*pos=*/1); + + // Apply Unswitch + tv1->axis(1)->parallelize(ParallelType::Unswitch); + tv1->axis(-1)->parallelize(ParallelType::TIDx); + + // Circular Buffer with TMA loads + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + int64_t axis_extent = + ceilDiv(ceilDiv(tensor_inner_dim, bulk_inner_dim), unroll_dim); + if (axis_extent < number_of_stages) { + ASSERT_ANY_THROW(fe.runFusion({t0})); + return; + } + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, MultiDim) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + 
TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + + TensorView* tv1 = exp(tv0); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + // Constants + constexpr int64_t tma_outer_dim = 4; + constexpr int64_t tma_inner_dim = 32; + + // [M, N] -> [M, N/bid, bid] + reference->split(-1, tma_inner_dim); + // [M, N/bid, bid] -> [M/bod, bod, N/bid, bid] + reference->split(0, tma_outer_dim); + // [M/bod, bod, N/bid, bid] -> [M/bod, N/bid, bod, bid] + reference->reorder({{-2, -3}}); + + // Propagate TMA transform + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Apply inlineAt for TMA cache + inlineAllAt(tv1, /*pos=*/2); + + // Merge TMA tile and Parallelize + // [M/bod, N/bid, bod, bid] -> [M/bod, N/bid, bod * bid] + reference->merge(-2, -1); + // [M/bod, N/bid, bod * bid] -> [M/bod, N/bid, (bod * bid) / 128, 128] + reference->split(-1, 128); + + // Parallelize + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + // Circular Buffer with TMA loads + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->axis(-2)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::ones({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = at::exp(t0); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Pointwise) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + // Use TMA to load TV0 into shared memory + TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv3->setMemoryType(MemoryType::Shared); + + // Load TV1 into shared memory + TensorView* tv4 = tv1->cacheAfter(); + tv4->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv2; + + // Constants + constexpr int64_t bulk_inner_dim = 32; + + // [M, N] -> [M, N/bid, bid] + reference->split(-1, bulk_inner_dim); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set computeAt position + inlineAllAt(tv2, /*pos=*/2); + + // Circular Buffer with TMA loads + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(2)->parallelize(ParallelType::Bulk); + tv3->circularBuffer(number_of_stages); + + // Circular Buffer with set operation + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->circularBuffer(number_of_stages); + + // Split reference to parallelize TMA tile + reference->split(-1, 32); + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = 
at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t2 = t0 + t1; + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0, t1}); + + std::vector cg_outputs = fe.runFusion({t0, t1}); + compare(tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), t2); + testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = add(tv0, tv1); + fusion->addOutput(tv2); + + // Use TMA to load TV0 into shared memory + TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv3->setMemoryType(MemoryType::Shared); + + // Load TV1 into shared memory + TensorView* tv4 = tv1->cacheAfter(LoadStoreOpType::CpAsync); + tv4->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv2; + + // Constants + constexpr int64_t bulk_inner_dim = 32; + + // [M, N] -> [M, N/bid, bid] + reference->split(-1, bulk_inner_dim); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Set computeAt position + inlineAllAt(tv2, /*pos=*/2); + + // Circular Buffer with TMA loads + tv3->axis(0)->parallelize(ParallelType::BIDx); + tv3->axis(2)->parallelize(ParallelType::Bulk); + tv3->circularBuffer(number_of_stages); + + // Circular Buffer with set operation + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->circularBuffer(number_of_stages); + + // Split reference to parallelize TMA tile + reference->split(-1, 32); + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t2 = t0 + t1; + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0, t1}); + + std::vector cg_outputs = fe.runFusion({t0, t1}); + // TODO enable when test passes + // compare(tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), t2); + + // Expect failure because of missing predicate support for cpAsync loads. 
+ // See https://github.com/NVIDIA/Fuser/pull/2339 + ASSERT_ANY_THROW(testValidate( + fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__)); +} + +TEST_P(TmaCircularBufferingTest, Reduction) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + + TensorView* tv1 = sum(tv0, {-1}); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + constexpr int64_t examples_per_cta = 4; + constexpr int64_t bulk_inner_dim = 256; + + // [M, N] -> [M/epc, epc, N] + reference->split(0, examples_per_cta); + // [M/epc, epc, N] -> [M/epc, epc, N/bid, bid] + reference->split(-1, bulk_inner_dim); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // [M/epc, epc, N/bid, bid] -> [M/epc, epc, N] + reference->merge(-2, -1); + // [M/epc, epc, N] -> [M/epc, epc, N/tdx, tdx] + constexpr int64_t tdx = 128; + reference->split(-1, tdx); + + // Parallelize + reference->axis(0)->parallelize(ParallelType::BIDx); + reference->axis(-1)->parallelize(ParallelType::TIDx); + + // InlineMost automatically handles vectorize and tma dimensions + inlineMost(); + + // Circular Buffer with TMA loads + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(-1)->parallelize(ParallelType::Bulk); + tv2->circularBuffer(number_of_stages); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = sum(t0, {-1}); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0}); + + std::vector cg_outputs = fe.runFusion({t0}); + compare(tensor_outer_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Persistent) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + constexpr at::ScalarType dtype = at::ScalarType::Float; + constexpr int64_t correction = 0; + constexpr int64_t reduction_axis = 1; + constexpr bool keepdim = true; + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* x = makeContigTensor(2, aten_to_data_type(dtype)); + fusion->addInput(x); + + // Algorithm: + // x_norm = (x - x_mean) / sqrt(x_var) + Val* num_elem = x->getLoopDomain().at(reduction_axis)->extent(); + + TensorView* sum_x = sum(x, {reduction_axis}, /*keepdim=*/false); + TensorView* mean_x = div(sum_x, num_elem); + TensorView* bcast_mean = broadcast(mean_x, {false, true}); + + TensorView* x_mean_sub = sub(x, bcast_mean); + TensorView* x_mean_sub_sq = mul(x_mean_sub, x_mean_sub); + TensorView* sum_x_mean_sub_sq = + sum(x_mean_sub_sq, {reduction_axis}, /*keepdim=*/false); + TensorView* var_x = div(sum_x_mean_sub_sq, num_elem); + TensorView* bcast_var = broadcast(var_x, {false, true}); + + TensorView* x_norm = div(sub(x, bcast_mean), sqrt(bcast_var)); + fusion->addOutput(x_norm); + + // Load input from global to shared memory + TensorView* x_cache_smem = + x->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + x_cache_smem->setMemoryType(MemoryType::Shared); + + // Load input from shared memory to registers + x_cache_smem->cacheAfter(); + + // Store results in registers + x_norm->cacheBefore(); + + std::vector reduction_tvs = + scheduler_utils::getReductionTvs(fusion.get()); + + TensorView* 
reference_tv = x_norm; + + // boxDim array must be non-zero and less than or equal to 256 + constexpr int64_t width = 32; + constexpr int64_t vectorize = 4; + int64_t elem_per_compute_thread = tensor_inner_dim / width / vectorize; + constexpr int64_t examples_per_cta = 4; + + // Since multi-dim CpAsyncBulk has a size limit of 256 per dimension, + // we require multiple TMA operations to load the entire example in shared + // memory for pointwise kernel. + // + // Define TMA Box + // logical domain: [I1, I2] + x_cache_smem->split(0, examples_per_cta); + // split: [I0 / 4, 4, I2] + x_cache_smem->split(-1, 256); + // split: [I0/4, 4, I2/256, 256] + + // Schedule reference_tv + // logical domain: [I1, I2] + // split: [I1, I2/V (width / tdx), V] + reference_tv->split(-1, vectorize); + // split: [I1, EPCT, I2/V/EPCT (tdx), V] + reference_tv->split(-2, elem_per_compute_thread, /*inner_split=*/false); + // split: [I1, EPCT, I2/V/EPCT (tdx), U, V] + reference_tv->split(-2, 1); + // reorder: [I1, I2/V/EPCT (tdx), EPCT, U, V] + reference_tv->reorder({{-4, -3}, {-3, -4}}); + // reorder: [I1/EPC, EPC, I2/V/EPCT (tdx), EPCT, U, V] + reference_tv->split(0, examples_per_cta); + + TransformPropagator propagator(reference_tv); + std::vector all_tvs_except_cache = + ir_utils::allTvsExcept(fusion.get(), {x_cache_smem}); + SetSelector selector( + {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) + .traverse(&propagator); + + std::vector rfactor_tvs; + rfactor_tvs.reserve(reduction_tvs.size()); + std::transform( + reduction_tvs.begin(), + reduction_tvs.end(), + std::back_inserter(rfactor_tvs), + [](TensorView* tv) { return tv->rFactor({-3, -2, -1}); }); + + // Define Parallelization Schema + reference_tv->axis(0)->parallelize(ParallelType::BIDx); + reference_tv->axis(2)->parallelize(ParallelType::TIDx); + reference_tv->axis(-2)->parallelize(ParallelType::Unroll); + scheduler_utils::parallelizeAllLike(reference_tv); + + // Vectorize Cache + reference_tv->axis(-1)->parallelize(ParallelType::Vectorize); + + // InlineMost automatically handles vectorize and tma dimensions + inlineMost(); + + // Handle TMA Tensor + // Apply circular buffer after computeAt + x_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + if (examples_per_cta > 1) { + x_cache_smem->circularBuffer(number_of_stages); + } + + auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA, 0); + at::Tensor at_tv0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor at_tv1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + + // Compile with FusionExecutor directly to avoid scheduling + FusionExecutor fe; + fe.compileFusion(fusion.get(), {at_tv0}); + std::vector cg_outputs = fe.runFusion({at_tv0}); + + std::tuple at_var_mean = + at::var_mean(at_tv0, {-1}, correction, keepdim); + at::Tensor at_var = std::get<0>(at_var_mean); + at::Tensor at_mean = std::get<1>(at_var_mean); + at::Tensor at_output = (at_tv0 - at_mean) / sqrt(at_var); + + testValidate( + fusion.get(), cg_outputs, {at_tv0}, {at_output}, __LINE__, __FILE__); +} + +TEST_P(TmaCircularBufferingTest, Matmul) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // Algorithm + TensorView* tv0 = makeContigTensor(2); // (M, K) + TensorView* tv1 = makeContigTensor(2); // (K, N) + fusion->addInput(tv0); + fusion->addInput(tv1); + + TensorView* tv2 = broadcast(tv0, {false, false, true}); // (M, K, B) + TensorView* tv3 = 
broadcast(tv1, {true, false, false}); // (B, K, N) + TensorView* tv4 = mul(tv2, tv3); // M, K, N + TensorView* tv5 = sum(tv4, {1}); // M, R, N + fusion->addOutput(tv5); + + // CpAsyncBulk Store + TensorView* tv6 = tv5->cacheBefore(LoadStoreOpType::CpAsyncBulkTensorTile); + tv6->setMemoryType(MemoryType::Shared); + + // For register circular buffering + TensorView* tv0_cache_local = tv0->cacheAfter(); + TensorView* tv1_cache_local = tv1->cacheAfter(); + + // For shared memory circular buffering + TensorView* tv0_cache_smem = + tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + TensorView* tv1_cache_smem = + tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv0_cache_smem->setMemoryType(MemoryType::Shared); + tv1_cache_smem->setMemoryType(MemoryType::Shared); + + constexpr int64_t BSX = 64; + constexpr int64_t TSX = 32; + + // Step 0: [M, K, N] + // Step 1: [M, K, N/BSX, BSX] + tv6->split(-1, BSX); + + // Step 2: [M, K, N/BSX, BSX/TSX, TSX] + tv6->split(-1, TSX); + + // Step 3: [M, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(1, BSX); + + // Step 4: [M/BSX, BSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(0, BSX); + + // Step 5:[M/BSX, BSX/TSX, TSX, K/BSX, BSX, N/BSX, BSX/TSX, TSX] + tv6->split(1, TSX); + + // Step 6: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + tv6->reorder( + {{4, 7}, {7, 6}, {6, 5}, {2, 4}, {1, 3}, {3, 2}, {5, 1}, {0, 0}}); + + // Step 7a: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX (reduce)] + // Step 7b: [M/BSX, N/BSX, K/BSX (reduce), BSX/TSX, BSX/TSX, TSX, TSX] + TensorView* tv6_rf = tv6->rFactor({-1}); + + TransformPropagatorWithCheck propagator(tv6_rf); + MaxLogicalDomainInfoSpanningTree(tv6_rf).traverse(&propagator); + + // Parallelize + tv5->axis(0)->parallelize(ParallelType::BIDx); + tv5->axis(1)->parallelize(ParallelType::BIDy); + tv5->axis(-3)->parallelize(ParallelType::TIDy); + tv5->axis(-1)->parallelize(ParallelType::TIDx); + + scheduler_utils::parallelizeAllLike(tv5); + + // (BSX/TSX * TSX * BSX) = 1024 floats = 4096 bytes * (number of buffers) + // Apply circular buffering to smem and local cache tensors + tv0_cache_smem->axis(-3)->parallelize(ParallelType::Bulk); + tv0_cache_smem->axis(-2)->parallelize(ParallelType::Bulk); + tv0_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + + tv1_cache_smem->axis(-3)->parallelize(ParallelType::Bulk); + tv1_cache_smem->axis(-2)->parallelize(ParallelType::Bulk); + tv1_cache_smem->axis(-1)->parallelize(ParallelType::Bulk); + + // Apply ParallelType::Bulk to global output tensor. 
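+  // tv5 is produced by the CpAsyncBulk store from the shared memory tensor tv6
+  // (created with cacheBefore above), so its tile axes are marked Bulk as well.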
+ tv5->axis(-4)->parallelize(ParallelType::Bulk); + tv5->axis(-3)->parallelize(ParallelType::Bulk); + tv5->axis(-2)->parallelize(ParallelType::Bulk); + tv5->axis(-1)->parallelize(ParallelType::Bulk); + + // IterDomain: [M/BSX, N/BSX, K/BSX, BSX/TSX, BSX/TSX, TSX, TSX, BSX] + // Parallelization: BDX, BDY, K/BSX ||, BSX/TSX, BSX/TSX, TDY, TSX, TDX] + // 4 non-parallelized for-loops + inlineMost(); + + // Apply circular buffering after setting computeAt position + tv0_cache_local->circularBuffer(number_of_stages); + tv1_cache_local->circularBuffer(number_of_stages); + + tv0_cache_smem->circularBuffer(number_of_stages); + tv1_cache_smem->circularBuffer(number_of_stages); + + constexpr int64_t K = 1024; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, K}, options); + at::Tensor t1 = at::randn({K, tensor_inner_dim}, options); + at::Tensor aten_output = + (t0.unsqueeze(/*dim=*/-1) * t1.unsqueeze(/*dim=*/0)).sum(/*dim=*/1); + + FusionExecutor fe; + fe.compileFusion(fusion.get(), {t0, t1}); + + std::vector cg_outputs = fe.runFusion({t0, t1}); + compare( + tensor_outer_dim, tensor_inner_dim, cg_outputs.front(), aten_output); + testValidate( + fusion.get(), cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__); +} + +// Test circular buffer from 2 to 5 stages +INSTANTIATE_TEST_SUITE_P( + Hopper, + TmaCircularBufferingTest, + testing::Combine( + ::testing::Range(2, 5), + testing::Values(128, 500, 1024), + testing::Values(128, 1024))); + } // namespace nvfuser