diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp
index 9fdf0445f51..bee61c46873 100644
--- a/csrc/device_lower/pass/allocation.cpp
+++ b/csrc/device_lower/pass/allocation.cpp
@@ -89,7 +89,9 @@ Expr* initializeMbarrier(
     // threads in the CTA.
     num_of_arrives = SimplifyingIrBuilder::maybeCastExpr(
         DataType::UInt32,
-        GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock());
+        GpuLower::current()
+            ->parallelDimensionMap()
+            .getNumComputeThreadsEachBlock());
   }
 
   // Initialize mbarrier for each circular buffer stage. Use the thread
diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp
index 19ebe3c6e63..15ed808b936 100644
--- a/csrc/device_lower/pass/circular_buffer.cpp
+++ b/csrc/device_lower/pass/circular_buffer.cpp
@@ -106,6 +106,10 @@ class CircularBufferLoopCloner : public kir::IrVisitor {
             SimplifyingIrBuilder::create<Val>(opt.prefetch, DataType::Index));
         break;
       }
+      case CircularBufferLoopStage::LoadWarp:
+      case CircularBufferLoopStage::ComputeWarp: {
+        break;
+      }
       default: {
        NVF_THROW("Unsupported loop mode, got: ", loop_type_);
       }
@@ -1246,11 +1250,22 @@ class CircularBufferInserter : private kir::ExprMutator {
       return;
     }
 
-    auto hasCpAsyncBulk = std::any_of(
+    auto has_cp_async_bulk = std::any_of(
         it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk);
 
-    if (hasCpAsyncBulk) {
-      insertTma(loop, it->second);
+    bool use_warp_specialization = std::holds_alternative<WarpSpecialized>(
+        GpuLower::current()
+            ->circularBufferInfo()
+            .getCircularBufferOptionsFor(loop->iter_domain())
+            .type);
+    if (use_warp_specialization) {
+      NVF_ERROR(
+          std::all_of(
+              it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk),
+          "In order to use warp specialization, all buffers must be loaded by TMA");
+      insertTmaWarpSpecialized(loop, it->second);
+    } else if (has_cp_async_bulk) {
+      insertTmaPipelined(loop, it->second);
     } else {
       insert(loop, it->second);
     }
@@ -1315,7 +1330,7 @@ class CircularBufferInserter : private kir::ExprMutator {
         .usesMBarrierForWAR();
   }
 
-  void insertTma(
+  void insertTmaPipelined(
       ForLoop* circular_buffer_loop,
       const std::vector<Expr*>& loads) {
     // Arrive on the WAR mbarriers to let the prefetching start.
@@ -1363,6 +1378,39 @@ class CircularBufferInserter : private kir::ExprMutator {
     registerInsertAfter(circular_buffer_loop, epilogue_loop);
   }
 
+  void insertTmaWarpSpecialized(
+      ForLoop* circular_buffer_loop,
+      const std::vector<Expr*>& loads) {
+    const auto& opt =
+        GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor(
+            circular_buffer_loop->iter_domain());
+    ParallelType warp_specialize_on = std::get<WarpSpecialized>(opt.type).on;
+
+    kir::IfThenElse* warp_dispatch_ite = IrBuilder::create<kir::IfThenElse>(
+        IrBuilder::create<kir::Predicate>(IrBuilder::eqExpr(
+            NamedScalar::getParallelIndex(warp_specialize_on),
+            IrBuilder::subExpr(
+                GpuLower::current()->parallelDimensionMap().get(
+                    warp_specialize_on),
+                circular_buffer_loop->fusion()->oneVal()))));
+
+    // Load loop:
+    ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp);
+    warp_dispatch_ite->thenBody().push_back(load_loop);
+
+    // Prefetch:
+    auto prefetch_loop = createArrivesForWar(circular_buffer_loop);
+    warp_dispatch_ite->elseBody().push_back(prefetch_loop);
+
+    // Compute loop:
+    ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone(
+        circular_buffer_loop, loads, CircularBufferLoopStage::ComputeWarp);
+    warp_dispatch_ite->elseBody().push_back(compute_loop);
+
+    registerReplace(circular_buffer_loop, warp_dispatch_ite);
+  }
+
   void insert(ForLoop* circular_buffer_loop, const std::vector<Expr*>& loads) {
     NVF_ERROR(
         !usesMBarrierForWAR(circular_buffer_loop),
diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp
index 9093cc378d3..4e2f55323be 100644
--- a/csrc/device_lower/pass/insert_syncs.cpp
+++ b/csrc/device_lower/pass/insert_syncs.cpp
@@ -468,6 +468,18 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
       last_writes_.pop_front();
       // Found that a sync is needed
 
+      if (!sync_bitmap.hasBID() &&
+          std::all_of(
+              expr->inputs().begin(), expr->inputs().end(), [](Val* val) {
+                return !val->isA<TensorView>() ||
+                    val->as<TensorView>()->getMemoryType() !=
+                        MemoryType::Shared ||
+                    ir_utils::isCpAsyncBulkLoad(val->definition());
+              })) {
+        // RAW of TMA is handled separately, so skip it here.
+        return;
+      }
+
       // TODO: Explicitly test the 3 cases below
       Expr* sync_expr = nullptr;
       kir::Allocate* maybe_alloc = nullptr;
diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp
index c61dad11601..afff481fb94 100644
--- a/csrc/device_lower/pass/predicate.cpp
+++ b/csrc/device_lower/pass/predicate.cpp
@@ -246,17 +246,41 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
       IrBuilder::create<UnaryOp>(
          UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
 
+      auto load_warp_loop_it =
+          std::find_if(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) {
+            return fl->circularBufferLoopStage() ==
+                CircularBufferLoopStage::LoadWarp;
+          });
+      ParallelType load_warp_on = ParallelType::Serial;
+      if (load_warp_loop_it != for_loops_.end()) {
+        load_warp_on = std::get<WarpSpecialized>(
+                           GpuLower::current()
+                               ->circularBufferInfo()
+                               .getCircularBufferOptionsFor(
+                                   (*load_warp_loop_it)->iter_domain())
+                               .type)
+                           .on;
+      }
+
+      // If we are in a load warp, then the warp-dispatching IfThenElse
+      // already selects on `load_warp_on`, so we should not generate
+      // predicates for it here.
       const auto& pdim_map = GpuLower::current()->parallelDimensionMap();
-      Val* first_warp = IrBuilder::ltExpr(
-          NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
+      Val* conditional = load_warp_on == ParallelType::TIDx
+          ? pred->fusion()->trueVal()
+          : SimplifyingIrBuilder::logicalAndExpr(
+                elect_sync_val,
+                IrBuilder::ltExpr(
+                    NamedScalar::getParallelIndex(ParallelType::TIDx),
+                    warp_size));
       for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
-        if (pdim_map.has(pt)) {
-          first_warp = SimplifyingIrBuilder::logicalAndExpr(
-              first_warp,
+        if (pdim_map.has(pt) && load_warp_on != pt) {
+          conditional = SimplifyingIrBuilder::logicalAndExpr(
+              conditional,
               IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
         }
       }
-      return SimplifyingIrBuilder::logicalAndExpr(first_warp, elect_sync_val);
+      return conditional;
     }
     default:
       break;
diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h
index dc0a92b5602..4e53a6207cc 100644
--- a/csrc/ir/interface_nodes.h
+++ b/csrc/ir/interface_nodes.h
@@ -113,7 +113,7 @@ class TVDomainGuard;
 //
 //              /load 0;\                                                  \.
-//             / load 1;  [prefetch = 3]                                    | [prologue]
+//             / load 1;  [prefetch = 3]                                    | [prefetching]
 //   [stage]     load 2;/                                                  /'
 //   [ = 6 ]     load 3;    wait load 0;    compute 0;                     \.
 //             \ load 4;    wait load 1;    compute 1;                      |
@@ -123,7 +123,7 @@ class TVDomainGuard;
 //               load 2;    wait load 5;    compute 5;    wait compute 3;   |
 //               load 3;    wait load 0;    compute 0;    wait compute 4;   |
 //               load 4;    wait load 1;    compute 1;    wait compute 5;   | [main]
-//               load 5;    wait load 2;    compute 2;    wait compute 0;   | [loop]
+//               load 5;    wait load 2;    compute 2;    wait compute 0;   |
 //               ..................................................         |
 //               ..................................................         |
 //               ..................................................         |
@@ -132,7 +132,7 @@ class TVDomainGuard;
 //               load  ;    wait load  ;    compute  ;    wait compute  ;   |
 //               load  ;    wait load  ;    compute  ;                     /'
 //                         /wait load  ;    compute  ;                     \.
-//  [same number as prefetch] wait load  ;  compute  ;                      | [epilogue]
+//  [same number as prefetch] wait load  ;  compute  ;                      | [draining]
 //                         \wait load  ;    compute  ;  wait all computes; /'
 // clang-format on
@@ -142,19 +142,37 @@ class TVDomainGuard;
 //
 // load pipeline depth = prefetch + 1
 // compute pipeline depth = stage - prefetch
 //
-// The above timeline can be implemented as the following loop structure:
+// There are two ways to implement the above timeline: pipelined and
+// warp-specialized.
+//
+// In the pipelined approach, the prefetching stage is implemented as a
+// prologue loop, the main stage as a main loop, and the draining stage as an
+// epilogue loop. That is, we will have the following loop structure:
 //
 // Prologue loop:
 //   for i in range(prefetch):
 //     load data[i] to buffer[i]
 //
-// Main loop:
+// Main loop (using syncthreads to avoid WAR hazard):
 //   for i in range(data.size - prefetch):
 //     load data[i + prefetch] to buffer[(i + prefetch) % stage]
-//     wait buffer[i % stage] to be ready
+//     wait buffer[i % stage] to be loaded
 //     compute buffer[i % stage]
 //     wait until the first compute in the queue is done
 //     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     __syncthreads();
+//
+// Main loop (using mbarrier to avoid WAR hazard):
+//   for i in range(data.size - prefetch):
+//     wait buffer[(i + prefetch) % stage] to be empty
+//     load data[i + prefetch] to buffer[(i + prefetch) % stage]
+//     wait buffer[i % stage] to be loaded
+//     compute buffer[i % stage]
+//     wait until the first compute in the queue is done
+//     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+//     loaded again
 //
 // Epilogue loop:
 //   for i in range(data.size - prefetch, data.size):
@@ -166,6 +184,30 @@ class TVDomainGuard;
 //   stage - prefetch - 1 iterations and last iteration of the main loop is
 //   redundant. We can remove them to further optimize the performance, but
 //   we decide to keep them for simplicity.
+//
+// In the warp-specialized approach, we use different warps/warp-groups for
+// loading and computing. We will generate code like below (assuming warp
+// specialized on TIDy):
+//
+// if (threadIdx.y == blockDim.y - 1) {
+//   // If we use warp specialization on TIDy, then the blockDim.y of the
+//   // kernel will be (whatever_value_inferred_from_schedule + 1), and the
+//   // last threadIdx.y will be used as the load warp
+//   for i in range(data.size):
+//     wait buffer[i % stage] to be empty
+//     load data[i] to buffer[i % stage]
+// } else {
+//   // Every threadIdx.y other than the last will be used for compute
+//   for i in range(prefetch + 1):
+//     signal that buffer i % stage is empty and ready to load
+//   for i in range(data.size):
+//     wait buffer[i % stage] to be loaded
+//     compute buffer[i % stage]
+//     wait until the first compute in the queue is done
+//     (i.e. stage - prefetch - 1 in flight computes remaining)
+//     signal that buffer (i + prefetch + 1) % stage is empty and ready to be
+//     loaded again
+// }
 
 struct Pipelined {
   bool uses_mbarrier_for_war = false;
@@ -184,7 +226,36 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) {
   return os << "Pipelined";
 }
 
-using CircularBufferType = std::variant<Pipelined>;
+struct WarpSpecialized {
+  ParallelType on;
+  explicit WarpSpecialized(ParallelType on) : on(on) {}
+  WarpSpecialized() = default;
+  bool operator==(const WarpSpecialized& other) const {
+    return on == other.on;
+  }
+};
+
+inline std::ostream& operator<<(
+    std::ostream& os,
+    const WarpSpecialized& warp_specialized) {
+  std::string parallel_type_str = "";
+  switch (warp_specialized.on) {
+    case ParallelType::TIDx:
+      parallel_type_str = "TIDx";
+      break;
+    case ParallelType::TIDy:
+      parallel_type_str = "TIDy";
+      break;
+    case ParallelType::TIDz:
+      parallel_type_str = "TIDz";
+      break;
+    default:
+      NVF_THROW("Invalid parallel type");
+  }
+  return os << "WarpSpecializedOn" << parallel_type_str;
+}
+
+using CircularBufferType = std::variant<Pipelined, WarpSpecialized>;
 
 inline std::ostream& operator<<(
     std::ostream& os,
@@ -207,8 +278,9 @@ struct CircularBufferOptions {
   }
 
   bool usesMBarrierForWAR() const {
-    return std::holds_alternative<Pipelined>(type) &&
-        std::get<Pipelined>(type).uses_mbarrier_for_war;
+    return (std::holds_alternative<Pipelined>(type) &&
+            std::get<Pipelined>(type).uses_mbarrier_for_war) ||
+        std::holds_alternative<WarpSpecialized>(type);
     return false;
   }
 
diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp
index 7e8df469d09..1878fdda59b 100644
--- a/csrc/parallel_dimension_map.cpp
+++ b/csrc/parallel_dimension_map.cpp
@@ -38,9 +38,17 @@ struct hash {
 namespace nvfuser {
 
 void ParallelDimensionMap::build(Fusion* fusion) {
+  VectorOfUniqueEntries<ParallelType> warp_specialized_types;
   VectorOfUniqueEntries all_concrete_ids;
   auto all_vals = fusion->usedMathVals();
   for (auto tv : ir_utils::filterByType<TensorView>(all_vals)) {
+    if (tv->isCircularBuffered() &&
+        std::holds_alternative<WarpSpecialized>(
+            tv->circularBufferOptions().type)) {
+      const auto& warp_specialized =
+          std::get<WarpSpecialized>(tv->circularBufferOptions().type);
+      warp_specialized_types.pushBack(warp_specialized.on);
+    }
     for (auto id : tv->domain()->allIDs()) {
       auto ptype = id->getParallelType();
       if (!isParallelTypeThread(ptype)) {
         continue;
       }
@@ -83,6 +91,10 @@ void ParallelDimensionMap::build(Fusion* fusion) {
   }
 
   adjustMappingsForWarpPadding();
+
+  for (auto pt : warp_specialized_types) {
+    setWarpSpecializeOn(pt);
+  }
 }
 
 void ParallelDimensionMap::adjustMappingsForWarpPadding() {
@@ -137,6 +149,17 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() {
   exact_types_.erase(ParallelType::TIDx);
 }
 
+void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) {
+  auto dim_it = dim_map_.find(pt);
+  if (dim_it == dim_map_.end()) {
+    dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
+  } else {
+    dim_map_[pt] = SimplifyingIrBuilder::addExpr(dim_it->second, 1);
+  }
+  exact_types_.erase(pt);
+  warp_specialized_types_.insert(pt);
+}
+
 Val* ParallelDimensionMap::getRaw(ParallelType pt) const {
   NVF_ERROR(isParallelTypeThread(pt), "Invalid ParallelType: ", pt);
   auto it = dim_map_.find(pt);
@@ -159,13 +182,16 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {
   return exact_types_.find(pt) != exact_types_.end();
 }
 
-Val* ParallelDimensionMap::getNumThreadsEachBlock() const {
+Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
   Val* num_threads = FusionGuard::getCurFusion()->oneVal();
   for (auto pt : kParallelTypeTIDs) {
     auto dim = getRaw(pt);
     if (dim == nullptr) {
       continue;
     }
+    if (warp_specialized_types_.find(pt) != warp_specialized_types_.end()) {
+      dim = SimplifyingIrBuilder::addExpr(dim, -1);
+    }
     num_threads = SimplifyingIrBuilder::mulExpr(num_threads, dim);
   }
   return num_threads;
diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h
index 7699bfffd3e..0e25182c12e 100644
--- a/csrc/parallel_dimension_map.h
+++ b/csrc/parallel_dimension_map.h
@@ -41,7 +41,19 @@ class ParallelDimensionMap {
     return dim_map_;
   }
 
-  Val* getNumThreadsEachBlock() const;
+  //! Get the number of threads per CTA used for computation. When there is no
+  //! warp specialization, the result is trivial: it is just the product of the
+  //! parallel dimensions of TIDx, TIDy and TIDz. If we do have warp
+  //! specialization, this returns the number of threads used for computing.
+  //! For example, if we have a simple kernel warp specialized on TIDy, all the
+  //! TIDx parallelized IterDomains have extent 32, all the TIDy parallelized
+  //! IterDomains have extent 16, and there is no TIDz parallelization, then we
+  //! will have:
+  //!   blockDim = (x=32, y=17, z=1)
+  //! and this function will return (32 * 16), because the extra one for TIDy
+  //! is introduced by warp specialization and is only used for loading
+  //! circular buffer tensors.
+  Val* getNumComputeThreadsEachBlock() const;
 
   bool has(ParallelType pt) const {
     return dim_map_.count(pt) > 0;
@@ -52,6 +64,11 @@ class ParallelDimensionMap {
   //! multiple of the warp size.
   void adjustMappingsForWarpPadding();
 
+  //! If we are doing warp specialization on pt, then we need to increase the
+  //! parallel dimension size of pt by one, where the extra one is used as the
+  //! load warp. In this case, pt becomes non-exact.
+  void setWarpSpecializeOn(ParallelType pt);
+
  private:
   //! Maps from parallel types to dimensions, which are constant if
   //! a unique value is found.
@@ -59,6 +76,9 @@ class ParallelDimensionMap {
   //! Set of parallel types whose dimensions are identified to be
   //! exactly the same as extents of mapped domains.
   std::unordered_set<ParallelType> exact_types_;
+
+  //! Set of parallel types that we are doing warp specialization on
+  std::unordered_set<ParallelType> warp_specialized_types_;
 };
 
 } // namespace nvfuser
diff --git a/csrc/type.cpp b/csrc/type.cpp
index c216fd63f74..ab087361a1d 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -1464,6 +1464,12 @@ std::ostream& operator<<(
     case CircularBufferLoopStage::Epilog:
       os << "{CircularBufferEpilog}";
       break;
+    case CircularBufferLoopStage::LoadWarp:
+      os << "{LoadWarp}";
+      break;
+    case CircularBufferLoopStage::ComputeWarp:
+      os << "{ComputeWarp}";
+      break;
     default:
       NVF_THROW("unknown circular buffer stage");
   }
diff --git a/csrc/type.h b/csrc/type.h
index 65a5c65e4c9..89cebe8763b 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -773,6 +773,8 @@ enum class CircularBufferLoopStage {
   Prolog = 0,
   Main,
   Epilog,
+  LoadWarp,
+  ComputeWarp,
   EndOfStages, // A special placeholder used to iterate over all stages
   NotApplicable
 };
@@ -782,7 +784,8 @@ enum class CircularBufferLoopStage {
 // e.g., No additional loads are required for the Epilogue stage.
 inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) {
   return stage == CircularBufferLoopStage::Prolog ||
-      stage == CircularBufferLoopStage::Main;
+      stage == CircularBufferLoopStage::Main ||
+      stage == CircularBufferLoopStage::LoadWarp;
 }
 
 // The consuming expressions of circular buffer are cloned for these circular
@@ -790,7 +793,8 @@ inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) {
 // e.g., No actual computation occurs in the Prologue stage.
 inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) {
   return stage == CircularBufferLoopStage::Main ||
-      stage == CircularBufferLoopStage::Epilog;
+      stage == CircularBufferLoopStage::Epilog ||
+      stage == CircularBufferLoopStage::ComputeWarp;
 }
 
 // A loop type may have WAR hazard if any of the following is true:
@@ -800,7 +804,9 @@ inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) {
 //   properly handled, could be overwriten by a circular buffer loading
 //   somewhere (*may or may not be in this loop*)
 inline bool mayHaveWarHazard(CircularBufferLoopStage stage) {
-  return stage == CircularBufferLoopStage::Main;
+  return stage == CircularBufferLoopStage::Main ||
+      stage == CircularBufferLoopStage::LoadWarp ||
+      stage == CircularBufferLoopStage::ComputeWarp;
 }
 
 //! Supported swizzle types,
diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp
index fd601027551..725f2c128b9 100644
--- a/tests/cpp/test_circular_buffering.cpp
+++ b/tests/cpp/test_circular_buffering.cpp
@@ -1266,11 +1266,16 @@ TEST_P(TmaCircularBufferingTest, Pointwise) {
   tv3->circularBuffer(
       number_of_stages, prefetch_distance, circular_buffer_type);
 
-  // Circular Buffer with set operation
-  // Load TV1 into shared memory
-  tv4->setMemoryType(MemoryType::Shared);
-  tv4->circularBuffer(
-      number_of_stages, prefetch_distance, circular_buffer_type);
+  // Circular Buffer with set operation.
+  // Note that in order to use warp specialization, all circular buffers must
+  // be loaded by TMA, so for this test we disable circular buffering of the
+  // set op when testing warp specialization.
+  if (!std::holds_alternative<WarpSpecialized>(circular_buffer_type)) {
+    // Load TV1 into shared memory
+    tv4->setMemoryType(MemoryType::Shared);
+    tv4->circularBuffer(
+        number_of_stages, prefetch_distance, circular_buffer_type);
+  }
 
   // Split reference to parallelize TMA tile
   reference->split(-1, 32);
@@ -1361,6 +1366,12 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) {
 TEST_P(TmaCircularBufferingTest, InnerReduction) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
 
+  if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) {
+    GTEST_SKIP()
+        << "This test uses block reduce, which uses a hard-coded blockDim "
+        << "and can deadlock when combined with warp specialization.";
+  }
+
   std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
 
@@ -1475,6 +1486,13 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) {
 TEST_P(TmaCircularBufferingTest, Persistent) {
   NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
 
+  if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) {
+    GTEST_SKIP()
+        << "This test uses block reduce and block broadcast, "
+        << "which have a hard-coded blockDim and "
+        << "can deadlock when combined with warp specialization.";
+  }
+
   constexpr at::ScalarType dtype = at::ScalarType::Float;
   constexpr int64_t correction = 0;
   constexpr int64_t reduction_axis = 1;
@@ -1847,7 +1865,10 @@ auto tmaCircularBufferingParams() {
   // https://en.wikipedia.org/wiki/Lehmer_random_number_generator
   uint32_t lcg_parkmiller = 1;
   const std::vector<CircularBufferType> all_types{
-      Pipelined(false), Pipelined(true)};
+      Pipelined(false),
+      Pipelined(true),
+      WarpSpecialized(ParallelType::TIDx),
+      WarpSpecialized(ParallelType::TIDy)};
   std::vector values;
   for (int64_t i : {2, 4}) {
     for (int64_t j : c10::irange(-i, i)) {
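For reference, below is a minimal usage sketch (not part of the patch) of the new WarpSpecialized circular buffer option. The circularBuffer() call and the Pipelined/WarpSpecialized option types follow the test code in this diff; the include paths, the helper function name, and the concrete stage/prefetch values are illustrative assumptions.

// Sketch: opting a TMA-loaded shared-memory TensorView into warp-specialized
// circular buffering. Assumes `tv` is already loaded via cp.async.bulk (TMA)
// and has a circular-bufferable loop axis; headers below are assumed paths.
#include <fusion.h>            // assumed: Fusion, TensorView
#include <ir/interface_nodes.h> // assumed: Pipelined, WarpSpecialized
#include <type.h>              // assumed: ParallelType

using namespace nvfuser;

void scheduleWarpSpecializedCircularBuffer(TensorView* tv) {
  constexpr int64_t number_of_stages = 4;   // illustrative
  constexpr int64_t prefetch_distance = 1;  // illustrative

  // Existing pipelined mode (prologue/main/epilogue loops) would be:
  //   tv->circularBuffer(number_of_stages, prefetch_distance, Pipelined(true));

  // Warp-specialized mode added by this patch: lowering emits an if/else on
  // TIDy, where the extra threadIdx.y introduced by
  // ParallelDimensionMap::setWarpSpecializeOn acts as the load warp and the
  // remaining threads run the compute loop.
  tv->circularBuffer(
      number_of_stages,
      prefetch_distance,
      WarpSpecialized(ParallelType::TIDy));
}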