From b605a22e218ee361ecda346f6942eb2ab2fe86a3 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 13:02:03 -0800 Subject: [PATCH 01/38] Add warp specialization as a circular buffering type --- csrc/device_lower/pass/allocation.cpp | 2 +- csrc/device_lower/pass/circular_buffer.cpp | 56 ++++++++++++++++++++-- csrc/device_lower/pass/insert_syncs.cpp | 37 +++++++++----- csrc/device_lower/pass/predicate.cpp | 16 ++++++- csrc/ir/interface_nodes.h | 24 ++++++++-- csrc/parallel_dimension_map.cpp | 29 ++++++++++- csrc/parallel_dimension_map.h | 5 +- csrc/type.cpp | 6 +++ csrc/type.h | 12 +++-- 9 files changed, 161 insertions(+), 26 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 9fdf0445f51..90f80a14f2c 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -89,7 +89,7 @@ Expr* initializeMbarrier( // threads in the CTA. num_of_arrives = SimplifyingIrBuilder::maybeCastExpr( DataType::UInt32, - GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlock()); + GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlockIgnoringWarpSpecialization()); } // Initialize mbarrier for each circular buffer stage. Use the thread diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 19ebe3c6e63..15ed808b936 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -106,6 +106,10 @@ class CircularBufferLoopCloner : public kir::IrVisitor { SimplifyingIrBuilder::create(opt.prefetch, DataType::Index)); break; } + case CircularBufferLoopStage::LoadWarp: + case CircularBufferLoopStage::ComputeWarp: { + break; + } default: { NVF_THROW("Unsupported loop mode, got: ", loop_type_); } @@ -1246,11 +1250,22 @@ class CircularBufferInserter : private kir::ExprMutator { return; } - auto hasCpAsyncBulk = std::any_of( + auto has_cp_async_bulk = std::any_of( it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); - if (hasCpAsyncBulk) { - insertTma(loop, it->second); + bool use_warp_specialization = std::holds_alternative( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(loop->iter_domain()) + .type); + if (use_warp_specialization) { + NVF_ERROR( + std::all_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), + "In order to use warp specialization, all buffers must be loaded by TMA"); + insertTmaWarpSpecialized(loop, it->second); + } else if (has_cp_async_bulk) { + insertTmaPipelined(loop, it->second); } else { insert(loop, it->second); } @@ -1315,7 +1330,7 @@ class CircularBufferInserter : private kir::ExprMutator { .usesMBarrierForWAR(); } - void insertTma( + void insertTmaPipelined( ForLoop* circular_buffer_loop, const std::vector& loads) { // Arrive on the WAR mbarriers to let the prefetching start. 
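The dispatch above keys off the new WarpSpecialized variant of CircularBufferType. As a rough usage sketch (not part of this patch, and with illustrative stage/prefetch values and tensor name), a schedule opts into the warp-specialized path like this; the lowering shown here then emits a load-warp / compute-warp IfThenElse instead of the pipelined prologue/main/epilogue loops:

    // Hypothetical scheduling snippet: tv_smem stands in for a TMA-loaded
    // shared memory tensor; the stage count, prefetch distance, and the
    // choice of TIDy are examples only.
    tv_smem->setMemoryType(MemoryType::Shared);
    tv_smem->circularBuffer(
        /*number_of_stages=*/4,
        /*prefetch_distance=*/1,
        WarpSpecialized(ParallelType::TIDy));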
@@ -1363,6 +1378,39 @@ class CircularBufferInserter : private kir::ExprMutator { registerInsertAfter(circular_buffer_loop, epilogue_loop); } + void insertTmaWarpSpecialized( + ForLoop* circular_buffer_loop, + const std::vector& loads) { + const auto& opt = + GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor( + circular_buffer_loop->iter_domain()); + ParallelType warp_specialize_on = std::get(opt.type).on; + + kir::IfThenElse* warp_dispatch_ite = IrBuilder::create( + IrBuilder::create(IrBuilder::eqExpr( + NamedScalar::getParallelIndex(warp_specialize_on), + IrBuilder::subExpr( + GpuLower::current()->parallelDimensionMap().get( + warp_specialize_on), + circular_buffer_loop->fusion()->oneVal())))); + + // Load loop: + ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::LoadWarp); + warp_dispatch_ite->thenBody().push_back(load_loop); + + // Prefetch: + auto prefetch_loop = createArrivesForWar(circular_buffer_loop); + warp_dispatch_ite->elseBody().push_back(prefetch_loop); + + // Compute loop: + ForLoop* compute_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, loads, CircularBufferLoopStage::ComputeWarp); + warp_dispatch_ite->elseBody().push_back(compute_loop); + + registerReplace(circular_buffer_loop, warp_dispatch_ite); + } + void insert(ForLoop* circular_buffer_loop, const std::vector& loads) { NVF_ERROR( !usesMBarrierForWAR(circular_buffer_loop), diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 9093cc378d3..b07f227dfcd 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -204,7 +204,8 @@ class WarSyncInserter : private kir::ExprMutator { auto out_tvs = ir_utils::filterByType(expr->outputs()); for (auto out_tv : out_tvs) { if (out_tv->getMemoryType() != MemoryType::Shared || - GpuLower::current()->syncMap()->needsRawSync(out_tv).none()) { + GpuLower::current()->syncMap()->needsRawSync(out_tv).none() || + out_tv->isCircularBuffered()) { continue; } @@ -222,7 +223,8 @@ class WarSyncInserter : private kir::ExprMutator { auto inp_tvs = ir_utils::filterByType(expr->inputs()); for (auto inp_tv : inp_tvs) { if (inp_tv->getMemoryType() != MemoryType::Shared || - GpuLower::current()->syncMap()->needsRawSync(inp_tv).none()) { + GpuLower::current()->syncMap()->needsRawSync(inp_tv).none() || + inp_tv->isCircularBuffered()) { continue; } @@ -955,9 +957,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { const auto gpu_lower = GpuLower::current(); int64_t pending_ops = std::numeric_limits::max(); for (auto inp : expr->inputs()) { - if (async_inputs_in_current_scope_.count(inp) == 0) { - continue; - } + // if (async_inputs_in_current_scope_.count(inp) == 0) { + // continue; + // } auto tv = dynamic_cast(inp); if (tv == nullptr) { continue; @@ -973,7 +975,8 @@ class WarAsyncWaitInserter : private kir::ExprMutator { } auto stage = circular_buffer_loop->circularBufferLoopStage(); NVF_ERROR( - stage == CircularBufferLoopStage::Main, + stage == CircularBufferLoopStage::Main || + stage == CircularBufferLoopStage::ComputeWarp, "Only main circular buffer loop needs WAR async wait, ", "so the code should not reach here. 
Stage:", stage); @@ -983,6 +986,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { circular_buffer_loop->iter_domain()); pending_ops = std::min(pending_ops, opt.stage - opt.prefetch - 1); } + NVF_ERROR(pending_ops != std::numeric_limits::max()); return pending_ops; } @@ -1005,7 +1009,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { auto expr = *it; // If the input of the async op is not in the current scope, then this // async op is not related, so nothing to protect. - if (std::none_of( + if (for_loop->circularBufferLoopStage() != + CircularBufferLoopStage::ComputeWarp && + std::none_of( expr->inputs().begin(), expr->inputs().end(), [&](Val* val) { return async_inputs_in_current_scope_.count(val); })) { @@ -1031,10 +1037,19 @@ class WarAsyncWaitInserter : private kir::ExprMutator { for (auto [type, pending_ops] : types_and_pending_ops_to_protect) { auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops); while (!sync_exprs.empty()) { - registerInsertAfter( - for_loop->body().exprs().back(), - sync_exprs.back(), - &for_loop->body()); + // TODO: wrong + if (for_loop->circularBufferLoopStage() == CircularBufferLoopStage::ComputeWarp) { + NVF_ERROR(for_loop->body().exprs().back()->isA()); + registerInsertAfter( + for_loop->body().exprs().at(for_loop->body().exprs().size() - 2), + sync_exprs.back(), + &for_loop->body()); + } else { + registerInsertAfter( + for_loop->body().exprs().back(), + sync_exprs.back(), + &for_loop->body()); + } sync_exprs.pop_back(); } } diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index c61dad11601..1ae5f204af4 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -246,11 +246,25 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { IrBuilder::create( UnaryOpType::ElectSync, elect_sync_val, full_mask_val); + auto load_warp_loop_it = + std::find(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) { + return fl->circularBufferLoopStage() == + CircularBufferLoopStage::LoadWarp; + }); + const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); Val* first_warp = IrBuilder::ltExpr( NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { - if (pdim_map.has(pt)) { + bool in_load_warp_for_pt = load_warp_loop_it != for_loops_.end() && + std::get( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor( + (*load_warp_loop_it)->iter_domain()) + .type) + .on == pt; + if (pdim_map.has(pt) && !in_load_warp_for_pt) { first_warp = SimplifyingIrBuilder::logicalAndExpr( first_warp, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index dc0a92b5602..dee87dc343b 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -74,7 +74,7 @@ class TVDomainGuard; } // [Circular buffering] -// + // A non-circle-buffered loop looks like below (assuming both the load and the // compute are async ops): // for i in range(data.size): @@ -184,7 +184,22 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) { return os << "Pipelined"; } -using CircularBufferType = std::variant; +// For example, if `on` is TIDy, then will assign additional TIDy for cirular +// buffer loading. 
+struct WarpSpecialized { + ParallelType on; + bool operator==(const WarpSpecialized& other) const { + return on == other.on; + } +}; + +inline std::ostream& operator<<( + std::ostream& os, + const WarpSpecialized& warp_specialized) { + return os << "WarpSpecializedOn" << warp_specialized.on; +} + +using CircularBufferType = std::variant; inline std::ostream& operator<<( std::ostream& os, @@ -207,8 +222,9 @@ struct CircularBufferOptions { } bool usesMBarrierForWAR() const { - return std::holds_alternative(type) && - std::get(type).uses_mbarrier_for_war; + return (std::holds_alternative(type) && + std::get(type).uses_mbarrier_for_war) || + std::holds_alternative(type); return false; } diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 7e8df469d09..4274bd3dfc8 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -38,9 +38,17 @@ struct hash { namespace nvfuser { void ParallelDimensionMap::build(Fusion* fusion) { + VectorOfUniqueEntries warp_specialized_types; VectorOfUniqueEntries all_concrete_ids; auto all_vals = fusion->usedMathVals(); for (auto tv : ir_utils::filterByType(all_vals)) { + if (tv->isCircularBuffered() && + std::holds_alternative( + tv->circularBufferOptions().type)) { + const auto& warp_specialized = + std::get(tv->circularBufferOptions().type); + warp_specialized_types.pushBack(warp_specialized.on); + } for (auto id : tv->domain()->allIDs()) { auto ptype = id->getParallelType(); if (!isParallelTypeThread(ptype)) { @@ -83,6 +91,10 @@ void ParallelDimensionMap::build(Fusion* fusion) { } adjustMappingsForWarpPadding(); + + for (auto pt : warp_specialized_types) { + setWarpSpecializeOn(pt); + } } void ParallelDimensionMap::adjustMappingsForWarpPadding() { @@ -137,6 +149,18 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { exact_types_.erase(ParallelType::TIDx); } +void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) { + std::cout << "Warp specialize on: " << pt << std::endl; + auto dim_it = dim_map_.find(pt); + if (dim_it == dim_map_.end()) { + dim_map_[pt] = IrBuilder::create(2, DataType::Index); + } else { + dim_map_[pt] = SimplifyingIrBuilder::addExpr(dim_it->second, 1); + } + exact_types_.erase(pt); + warp_specialized_types_.insert(pt); +} + Val* ParallelDimensionMap::getRaw(ParallelType pt) const { NVF_ERROR(isParallelTypeThread(pt), "Invalid ParallelType: ", pt); auto it = dim_map_.find(pt); @@ -159,13 +183,16 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const { return exact_types_.find(pt) != exact_types_.end(); } -Val* ParallelDimensionMap::getNumThreadsEachBlock() const { +Val* ParallelDimensionMap::getNumThreadsEachBlockIgnoringWarpSpecialization() const { Val* num_threads = FusionGuard::getCurFusion()->oneVal(); for (auto pt : kParallelTypeTIDs) { auto dim = getRaw(pt); if (dim == nullptr) { continue; } + if (warp_specialized_types_.find(pt) != warp_specialized_types_.end()) { + dim = SimplifyingIrBuilder::addExpr(dim, -1); + } num_threads = SimplifyingIrBuilder::mulExpr(num_threads, dim); } return num_threads; diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 7699bfffd3e..5fce9050eed 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -41,7 +41,7 @@ class ParallelDimensionMap { return dim_map_; } - Val* getNumThreadsEachBlock() const; + Val* getNumThreadsEachBlockIgnoringWarpSpecialization() const; bool has(ParallelType pt) const { return dim_map_.count(pt) > 0; @@ -52,6 +52,8 @@ class ParallelDimensionMap { 
//! multiple of the warp size. void adjustMappingsForWarpPadding(); + void setWarpSpecializeOn(ParallelType pt); + private: //! Maps from parallel types to dimensions, which are constant if //! a unique value is found. @@ -59,6 +61,7 @@ class ParallelDimensionMap { //! Set of parallel types whose dimensions are identified to be //! exactly the same as extents of mapped domains. std::unordered_set exact_types_; + std::unordered_set warp_specialized_types_; }; } // namespace nvfuser diff --git a/csrc/type.cpp b/csrc/type.cpp index c216fd63f74..ab087361a1d 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -1464,6 +1464,12 @@ std::ostream& operator<<( case CircularBufferLoopStage::Epilog: os << "{CircularBufferEpilog}"; break; + case CircularBufferLoopStage::LoadWarp: + os << "{LoadWarp}"; + break; + case CircularBufferLoopStage::ComputeWarp: + os << "{ComputeWarp}"; + break; default: NVF_THROW("unknown circular buffer stage"); } diff --git a/csrc/type.h b/csrc/type.h index 65a5c65e4c9..89cebe8763b 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -773,6 +773,8 @@ enum class CircularBufferLoopStage { Prolog = 0, Main, Epilog, + LoadWarp, + ComputeWarp, EndOfStages, // A special placeholder used to iterate over all stages NotApplicable }; @@ -782,7 +784,8 @@ enum class CircularBufferLoopStage { // e.g., No additional loads are required for the Epilogue stage. inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Prolog || - stage == CircularBufferLoopStage::Main; + stage == CircularBufferLoopStage::Main || + stage == CircularBufferLoopStage::LoadWarp; } // The consuming expressions of circular buffer are cloned for these circular @@ -790,7 +793,8 @@ inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) { // e.g., No actual computation occurs in the Prologue stage. inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::Epilog; + stage == CircularBufferLoopStage::Epilog || + stage == CircularBufferLoopStage::ComputeWarp; } // A loop type may have WAR hazard if any of the following is true: @@ -800,7 +804,9 @@ inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) { // properly handled, could be overwriten by a circular buffer loading // somewhere (*may or may not be in this loop*) inline bool mayHaveWarHazard(CircularBufferLoopStage stage) { - return stage == CircularBufferLoopStage::Main; + return stage == CircularBufferLoopStage::Main || + stage == CircularBufferLoopStage::LoadWarp || + stage == CircularBufferLoopStage::ComputeWarp; } //! 
Supported swizzle types, From 6a11cdd740892fd5df57ea0c7d14041c28b35d30 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 13:06:00 -0800 Subject: [PATCH 02/38] revert --- csrc/device_lower/pass/insert_syncs.cpp | 37 ++++++++----------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index b07f227dfcd..9093cc378d3 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -204,8 +204,7 @@ class WarSyncInserter : private kir::ExprMutator { auto out_tvs = ir_utils::filterByType(expr->outputs()); for (auto out_tv : out_tvs) { if (out_tv->getMemoryType() != MemoryType::Shared || - GpuLower::current()->syncMap()->needsRawSync(out_tv).none() || - out_tv->isCircularBuffered()) { + GpuLower::current()->syncMap()->needsRawSync(out_tv).none()) { continue; } @@ -223,8 +222,7 @@ class WarSyncInserter : private kir::ExprMutator { auto inp_tvs = ir_utils::filterByType(expr->inputs()); for (auto inp_tv : inp_tvs) { if (inp_tv->getMemoryType() != MemoryType::Shared || - GpuLower::current()->syncMap()->needsRawSync(inp_tv).none() || - inp_tv->isCircularBuffered()) { + GpuLower::current()->syncMap()->needsRawSync(inp_tv).none()) { continue; } @@ -957,9 +955,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { const auto gpu_lower = GpuLower::current(); int64_t pending_ops = std::numeric_limits::max(); for (auto inp : expr->inputs()) { - // if (async_inputs_in_current_scope_.count(inp) == 0) { - // continue; - // } + if (async_inputs_in_current_scope_.count(inp) == 0) { + continue; + } auto tv = dynamic_cast(inp); if (tv == nullptr) { continue; @@ -975,8 +973,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { } auto stage = circular_buffer_loop->circularBufferLoopStage(); NVF_ERROR( - stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::ComputeWarp, + stage == CircularBufferLoopStage::Main, "Only main circular buffer loop needs WAR async wait, ", "so the code should not reach here. Stage:", stage); @@ -986,7 +983,6 @@ class WarAsyncWaitInserter : private kir::ExprMutator { circular_buffer_loop->iter_domain()); pending_ops = std::min(pending_ops, opt.stage - opt.prefetch - 1); } - NVF_ERROR(pending_ops != std::numeric_limits::max()); return pending_ops; } @@ -1009,9 +1005,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { auto expr = *it; // If the input of the async op is not in the current scope, then this // async op is not related, so nothing to protect. 
- if (for_loop->circularBufferLoopStage() != - CircularBufferLoopStage::ComputeWarp && - std::none_of( + if (std::none_of( expr->inputs().begin(), expr->inputs().end(), [&](Val* val) { return async_inputs_in_current_scope_.count(val); })) { @@ -1037,19 +1031,10 @@ class WarAsyncWaitInserter : private kir::ExprMutator { for (auto [type, pending_ops] : types_and_pending_ops_to_protect) { auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops); while (!sync_exprs.empty()) { - // TODO: wrong - if (for_loop->circularBufferLoopStage() == CircularBufferLoopStage::ComputeWarp) { - NVF_ERROR(for_loop->body().exprs().back()->isA()); - registerInsertAfter( - for_loop->body().exprs().at(for_loop->body().exprs().size() - 2), - sync_exprs.back(), - &for_loop->body()); - } else { - registerInsertAfter( - for_loop->body().exprs().back(), - sync_exprs.back(), - &for_loop->body()); - } + registerInsertAfter( + for_loop->body().exprs().back(), + sync_exprs.back(), + &for_loop->body()); sync_exprs.pop_back(); } } From 14361288d73865fc46c6cefe704e43707004e0a0 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 13:06:42 -0800 Subject: [PATCH 03/38] revert --- csrc/ir/interface_nodes.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index dee87dc343b..0ded1629943 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -74,7 +74,7 @@ class TVDomainGuard; } // [Circular buffering] - +// // A non-circle-buffered loop looks like below (assuming both the load and the // compute are async ops): // for i in range(data.size): From b1a873e4ab84beafa2a9d522cdf05a22c3d06549 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 13:13:28 -0800 Subject: [PATCH 04/38] fix --- csrc/device_lower/pass/predicate.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index 1ae5f204af4..97ed99f86c4 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -247,7 +247,7 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { UnaryOpType::ElectSync, elect_sync_val, full_mask_val); auto load_warp_loop_it = - std::find(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) { + std::find_if(for_loops_.begin(), for_loops_.end(), [](ForLoop* fl) { return fl->circularBufferLoopStage() == CircularBufferLoopStage::LoadWarp; }); From 5b3399c7c95a4efbc96ca4a24d4da5ca01baeade Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 14:08:06 -0800 Subject: [PATCH 05/38] save --- csrc/ir/interface_nodes.h | 2 ++ tests/cpp/test_circular_buffering.cpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 0ded1629943..57745092505 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -188,6 +188,8 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) { // buffer loading. 
struct WarpSpecialized { ParallelType on; + explicit WarpSpecialized(ParallelType on) : on(on) {} + WarpSpecialized() = default; bool operator==(const WarpSpecialized& other) const { return on == other.on; } diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index d0d0209cfb4..d5dd8c01831 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1784,7 +1784,7 @@ auto tmaCircularBufferingParams() { // https://en.wikipedia.org/wiki/Lehmer_random_number_generator uint32_t lcg_parkmiller = 1; const std::vector all_types{ - Pipelined(false), Pipelined(true)}; + Pipelined(false), Pipelined(true), WarpSpecialized(ParallelType::TIDy)}; std::vector values; for (int64_t i : {2, 4}) { for (int64_t j : c10::irange(-i, i)) { From 28a0931cfd10ac8a1b9d118427c422a69d60d5c0 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 14:32:18 -0800 Subject: [PATCH 06/38] str --- csrc/ir/interface_nodes.h | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 57745092505..67e112fd643 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -198,7 +198,21 @@ struct WarpSpecialized { inline std::ostream& operator<<( std::ostream& os, const WarpSpecialized& warp_specialized) { - return os << "WarpSpecializedOn" << warp_specialized.on; + std::string parallel_type_str = ""; + switch (warp_specialized.on) { + case ParallelType::TIDx: + parallel_type_str = "TIDx"; + break; + case ParallelType::TIDy: + parallel_type_str = "TIDy"; + break; + case ParallelType::TIDz: + parallel_type_str = "TIDz"; + break; + default: + NVF_THROW("Invalid parallel type"); + } + return os << "WarpSpecializedOn" << parallel_type_str; } using CircularBufferType = std::variant; From 08c2357c6ae8342de076f9d3bd91c5a65b32bf60 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 14:43:04 -0800 Subject: [PATCH 07/38] save --- csrc/parallel_dimension_map.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 4274bd3dfc8..2792bd75e7d 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -150,7 +150,6 @@ void ParallelDimensionMap::adjustMappingsForWarpPadding() { } void ParallelDimensionMap::setWarpSpecializeOn(ParallelType pt) { - std::cout << "Warp specialize on: " << pt << std::endl; auto dim_it = dim_map_.find(pt); if (dim_it == dim_map_.end()) { dim_map_[pt] = IrBuilder::create(2, DataType::Index); From 4eadab7809f208709b5d8f4e1b3125422d7bc51f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 14:55:09 -0800 Subject: [PATCH 08/38] save --- csrc/device_lower/pass/insert_syncs.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 9093cc378d3..dbd6e0459b2 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -468,6 +468,10 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { last_writes_.pop_front(); // Found that a sync is needed + if (!sync_bitmap.hasBID()) { + std::cout << "expr: " << expr->toString() << std::endl; + } + // TODO: Explicitly test the 3 cases below Expr* sync_expr = nullptr; kir::Allocate* maybe_alloc = nullptr; From f12c433912b5514a7ea89325d4a2f8360849dc40 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:02:06 -0800 Subject: [PATCH 09/38] save --- 
csrc/device_lower/pass/insert_syncs.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index dbd6e0459b2..b1f7e766cf7 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -468,8 +468,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { last_writes_.pop_front(); // Found that a sync is needed - if (!sync_bitmap.hasBID()) { - std::cout << "expr: " << expr->toString() << std::endl; + if (std::all_of( + expr->inputs().begin(), expr->inputs().end(), [](Val* val) { + return ir_utils::isCpAsyncBulkLoad(tv->definition()); + })) { + return; } // TODO: Explicitly test the 3 cases below From b70537208928d6112fe7ae64ee6cc6eee3b34c5f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:03:29 -0800 Subject: [PATCH 10/38] fix --- csrc/device_lower/pass/insert_syncs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index b1f7e766cf7..167a742da85 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -470,7 +470,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { if (std::all_of( expr->inputs().begin(), expr->inputs().end(), [](Val* val) { - return ir_utils::isCpAsyncBulkLoad(tv->definition()); + return ir_utils::isCpAsyncBulkLoad(val->definition()); })) { return; } From 5dd18b2cb9849e5f74422566e5d045965a8ff112 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:04:47 -0800 Subject: [PATCH 11/38] save --- csrc/device_lower/pass/insert_syncs.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 167a742da85..bd56e5d4b8f 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -468,7 +468,8 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { last_writes_.pop_front(); // Found that a sync is needed - if (std::all_of( + if (!sync_bitmap.hasBID() && + std::all_of( expr->inputs().begin(), expr->inputs().end(), [](Val* val) { return ir_utils::isCpAsyncBulkLoad(val->definition()); })) { From 0eec8714f68ef7727890e8281076db827a92bb75 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:07:29 -0800 Subject: [PATCH 12/38] remove assert --- csrc/device_lower/pass/circular_buffer.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 15ed808b936..354f74617e9 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -1259,10 +1259,6 @@ class CircularBufferInserter : private kir::ExprMutator { .getCircularBufferOptionsFor(loop->iter_domain()) .type); if (use_warp_specialization) { - NVF_ERROR( - std::all_of( - it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), - "In order to use warp specialization, all buffers must be loaded by TMA"); insertTmaWarpSpecialized(loop, it->second); } else if (has_cp_async_bulk) { insertTmaPipelined(loop, it->second); From ddae9d761b23a41fe4f951546e066eef0c3c8bb7 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:10:42 -0800 Subject: [PATCH 13/38] assert back --- csrc/device_lower/pass/circular_buffer.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/csrc/device_lower/pass/circular_buffer.cpp 
b/csrc/device_lower/pass/circular_buffer.cpp index 354f74617e9..15ed808b936 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -1259,6 +1259,10 @@ class CircularBufferInserter : private kir::ExprMutator { .getCircularBufferOptionsFor(loop->iter_domain()) .type); if (use_warp_specialization) { + NVF_ERROR( + std::all_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), + "In order to use warp specialization, all buffers must be loaded by TMA"); insertTmaWarpSpecialized(loop, it->second); } else if (has_cp_async_bulk) { insertTmaPipelined(loop, it->second); From ae0122a652407ce2b387295336aefc13fca829b4 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:13:20 -0800 Subject: [PATCH 14/38] save --- tests/cpp/test_circular_buffering.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index d5dd8c01831..bbbc62cc8a1 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1261,10 +1261,15 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { tv3->circularBuffer( number_of_stages, prefetch_distance, circular_buffer_type); - // Circular Buffer with set operation + // Circular Buffer with set operation. Note that in order to use warp + // specialization, all circilar buffers must be loaded by TMA, so for + // this test we disable circular buffering of set op if we are testing warp + // specialization. tv4->axis(0)->parallelize(ParallelType::BIDx); - tv4->circularBuffer( - number_of_stages, prefetch_distance, circular_buffer_type); + if (!std::holds_alternative(circular_buffer_type)) { + tv4->circularBuffer( + number_of_stages, prefetch_distance, circular_buffer_type); + } // Split reference to parallelize TMA tile reference->split(-1, 32); From d30ec31cc0fba833a1376bfcda7e1f29d070708e Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:17:02 -0800 Subject: [PATCH 15/38] save --- tests/cpp/test_circular_buffering.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index bbbc62cc8a1..c404b763be1 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1237,10 +1237,6 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); tv3->setMemoryType(MemoryType::Shared); - // Load TV1 into shared memory - TensorView* tv4 = tv1->cacheAfter(); - tv4->setMemoryType(MemoryType::Shared); - TensorView* reference = tv2; // Constants @@ -1267,6 +1263,9 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { // specialization. 
tv4->axis(0)->parallelize(ParallelType::BIDx); if (!std::holds_alternative(circular_buffer_type)) { + // Load TV1 into shared memory + TensorView* tv4 = tv1->cacheAfter(); + tv4->setMemoryType(MemoryType::Shared); tv4->circularBuffer( number_of_stages, prefetch_distance, circular_buffer_type); } From e2cb8d1b2460d5f1d47568089e5a1581dfb3272a Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:18:59 -0800 Subject: [PATCH 16/38] save --- tests/cpp/test_circular_buffering.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index c404b763be1..4d3c881d4fa 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1257,14 +1257,14 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { tv3->circularBuffer( number_of_stages, prefetch_distance, circular_buffer_type); - // Circular Buffer with set operation. Note that in order to use warp - // specialization, all circilar buffers must be loaded by TMA, so for - // this test we disable circular buffering of set op if we are testing warp - // specialization. - tv4->axis(0)->parallelize(ParallelType::BIDx); + // Circular Buffer with set operation. + // Note that in order to use warp specialization, all circilar buffers must be + // loaded by TMA, so for this test we disable circular buffering of set op if + // we are testing warp specialization. if (!std::holds_alternative(circular_buffer_type)) { // Load TV1 into shared memory TensorView* tv4 = tv1->cacheAfter(); + tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->setMemoryType(MemoryType::Shared); tv4->circularBuffer( number_of_stages, prefetch_distance, circular_buffer_type); From d8233b3ad9b4b05fd1540005d1c94542b4bdde23 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:20:18 -0800 Subject: [PATCH 17/38] comment --- csrc/device_lower/pass/insert_syncs.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index bd56e5d4b8f..1dc1b05c99e 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -473,6 +473,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { expr->inputs().begin(), expr->inputs().end(), [](Val* val) { return ir_utils::isCpAsyncBulkLoad(val->definition()); })) { + // RAW of TMA is handled separately, so skip it here. 
return; } From 26675e354ac6fd7dec087e7e2ec72329dc65cbca Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:22:55 -0800 Subject: [PATCH 18/38] save --- tests/cpp/test_circular_buffering.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 4d3c881d4fa..3569b6dc6ef 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1248,9 +1248,6 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { TransformPropagatorWithCheck propagator(reference); MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); - // Set computeAt position - inlineAllAt(tv2, /*pos=*/2); - // Circular Buffer with TMA loads tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(2)->parallelize(ParallelType::Bulk); @@ -1275,6 +1272,9 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { reference->axis(0)->parallelize(ParallelType::BIDx); reference->axis(-1)->parallelize(ParallelType::TIDx); + // Set computeAt position + inlineAllAt(tv2, /*pos=*/2); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); From 732072b464e9b78913d04a1ff56bb071f5cbbf3c Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:28:32 -0800 Subject: [PATCH 19/38] save --- tests/cpp/test_circular_buffering.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 3569b6dc6ef..b5fe908f7cb 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1237,6 +1237,9 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { TensorView* tv3 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); tv3->setMemoryType(MemoryType::Shared); + TensorView* tv4 = tv1->cacheAfter(); + tv4->axis(0)->parallelize(ParallelType::BIDx); + TensorView* reference = tv2; // Constants @@ -1248,6 +1251,9 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { TransformPropagatorWithCheck propagator(reference); MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + // Set computeAt position + inlineAllAt(tv2, /*pos=*/2); + // Circular Buffer with TMA loads tv3->axis(0)->parallelize(ParallelType::BIDx); tv3->axis(2)->parallelize(ParallelType::Bulk); @@ -1260,8 +1266,6 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { // we are testing warp specialization. 
if (!std::holds_alternative(circular_buffer_type)) { // Load TV1 into shared memory - TensorView* tv4 = tv1->cacheAfter(); - tv4->axis(0)->parallelize(ParallelType::BIDx); tv4->setMemoryType(MemoryType::Shared); tv4->circularBuffer( number_of_stages, prefetch_distance, circular_buffer_type); @@ -1272,9 +1276,6 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { reference->axis(0)->parallelize(ParallelType::BIDx); reference->axis(-1)->parallelize(ParallelType::TIDx); - // Set computeAt position - inlineAllAt(tv2, /*pos=*/2); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); at::Tensor t1 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); From ae756cc617bb70c03c4d03bac5a85a7354ddde2f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:39:57 -0800 Subject: [PATCH 20/38] save --- csrc/device_lower/pass/insert_syncs.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 1dc1b05c99e..4e2f55323be 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -471,7 +471,10 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { if (!sync_bitmap.hasBID() && std::all_of( expr->inputs().begin(), expr->inputs().end(), [](Val* val) { - return ir_utils::isCpAsyncBulkLoad(val->definition()); + return !val->isA() || + val->as()->getMemoryType() != + MemoryType::Shared || + ir_utils::isCpAsyncBulkLoad(val->definition()); })) { // RAW of TMA is handled separately, so skip it here. return; From 0e0964556a28a8c941642245df09f05950876239 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 15:49:42 -0800 Subject: [PATCH 21/38] save --- tests/cpp/test_circular_buffering.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index b5fe908f7cb..4c7c2271978 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1393,7 +1393,13 @@ TEST_P(TmaCircularBufferingTest, Reduction) { // Parallelize reference->axis(0)->parallelize(ParallelType::BIDx); - reference->axis(-1)->parallelize(ParallelType::TIDx); + + // Use block reduce if possible. + // Note that block reduce implies block sync, which can cause deadlock when + // combined with warp specialization. So we do serial reduction for this test. 
+ if (!std::holds_alternative(circular_buffer_type)) { + reference->axis(-1)->parallelize(ParallelType::TIDx); + } // InlineMost automatically handles vectorize and tma dimensions inlineMost(); From 93b00ee08020ef22e59d8c404d3def58cae70c8c Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:05:28 -0800 Subject: [PATCH 22/38] outer reduction --- tests/cpp/test_circular_buffering.cpp | 72 ++++++++++++++++++++++++--- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 4c7c2271978..a474641f81f 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1357,9 +1357,15 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) { testValidate(fusion.get(), cg_outputs, {t0, t1}, {t2}, __LINE__, __FILE__); } -TEST_P(TmaCircularBufferingTest, Reduction) { +TEST_P(TmaCircularBufferingTest, InnerReduction) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + if (std::holds_alternative(circular_buffer_type)) { + GTEST_SKIP() + << "This test uses block reduce, which implies block sync, " + << "which can cause deadlock when combined with warp specialization. " + } + std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -1394,12 +1400,8 @@ TEST_P(TmaCircularBufferingTest, Reduction) { // Parallelize reference->axis(0)->parallelize(ParallelType::BIDx); - // Use block reduce if possible. - // Note that block reduce implies block sync, which can cause deadlock when - // combined with warp specialization. So we do serial reduction for this test. - if (!std::holds_alternative(circular_buffer_type)) { - reference->axis(-1)->parallelize(ParallelType::TIDx); - } + // Use block reduce. + reference->axis(-1)->parallelize(ParallelType::TIDx); // InlineMost automatically handles vectorize and tma dimensions inlineMost(); @@ -1422,6 +1424,62 @@ TEST_P(TmaCircularBufferingTest, Reduction) { testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); } +TEST_P(TmaCircularBufferingTest, OuterReduction) { + NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + + if (std::holds_alternative(circular_buffer_type)) { + GTEST_SKIP() + << "This test uses block reduce, which implies block sync, " + << "which can cause deadlock when combined with warp specialization. 
" + } + + std::unique_ptr fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(2); + fusion->addInput(tv0); + + TensorView* tv1 = sum(tv0, {0}); + fusion->addOutput(tv1); + + TensorView* tv2 = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv2->setMemoryType(MemoryType::Shared); + + TensorView* reference = tv1; + + constexpr int64_t tile_size = 256; + + // [M, N] -> [M, N/bid, bid] + reference->split(1, tile_size); + + TransformPropagatorWithCheck propagator(reference); + MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); + + // Parallelize + reference->axis(1)->parallelize(ParallelType::BIDx); + reference->axis(2)->parallelize(ParallelType::TIDx); + tv2->axis(1)->parallelize(ParallelType::BIDx); + tv2->axis(2)->parallelize(ParallelType::Bulk); + + // InlineMost automatically handles vectorize and tma dimensions + inlineMost(); + + // Circular Buffer with TMA loads + tv2->circularBuffer( + number_of_stages, prefetch_distance, circular_buffer_type); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({tensor_outer_dim, tensor_inner_dim}, options); + at::Tensor t1 = sum(t0, {0}); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + std::vector cg_outputs = ke.run({t0}); + compare(tensor_outer_dim, cg_outputs.front(), t1); + testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); +} + TEST_P(TmaCircularBufferingTest, Persistent) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); From 7a20b027cebb78c32f94c78deaa11bbed933677f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:06:47 -0800 Subject: [PATCH 23/38] save --- tests/cpp/test_circular_buffering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index a474641f81f..57dac5e36d9 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1363,7 +1363,7 @@ TEST_P(TmaCircularBufferingTest, InnerReduction) { if (std::holds_alternative(circular_buffer_type)) { GTEST_SKIP() << "This test uses block reduce, which implies block sync, " - << "which can cause deadlock when combined with warp specialization. " + << "which can cause deadlock when combined with warp specialization."; } std::unique_ptr fusion = std::make_unique(); From 05a31283eb06539474cfd306095b3e67dfa8ad96 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:07:54 -0800 Subject: [PATCH 24/38] save --- tests/cpp/test_circular_buffering.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 57dac5e36d9..00e20986e5d 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1427,12 +1427,6 @@ TEST_P(TmaCircularBufferingTest, InnerReduction) { TEST_P(TmaCircularBufferingTest, OuterReduction) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (std::holds_alternative(circular_buffer_type)) { - GTEST_SKIP() - << "This test uses block reduce, which implies block sync, " - << "which can cause deadlock when combined with warp specialization. 
" - } - std::unique_ptr fusion = std::make_unique(); FusionGuard fg(fusion.get()); From f839382303820b4b469c0aa8ab9b5f2f45bd7925 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:20:23 -0800 Subject: [PATCH 25/38] save --- tests/cpp/test_circular_buffering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 00e20986e5d..2a518f28dc5 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1455,8 +1455,8 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { tv2->axis(1)->parallelize(ParallelType::BIDx); tv2->axis(2)->parallelize(ParallelType::Bulk); - // InlineMost automatically handles vectorize and tma dimensions - inlineMost(); + inlineAllAt(reference, /*pos=*/1); + // TODO: fix inlineMost(); // Circular Buffer with TMA loads tv2->circularBuffer( From eaf1ff5bcd7a958ba5882059858cc47ef293d2ac Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:21:36 -0800 Subject: [PATCH 26/38] save --- tests/cpp/test_circular_buffering.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 2a518f28dc5..a350fde232d 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1445,6 +1445,8 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { // [M, N] -> [M, N/bid, bid] reference->split(1, tile_size); + // [M, N/bid, bid] -> [N/bid, M, bid] + reference->reorder({{1, 0}}); TransformPropagatorWithCheck propagator(reference); MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); @@ -1455,8 +1457,7 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { tv2->axis(1)->parallelize(ParallelType::BIDx); tv2->axis(2)->parallelize(ParallelType::Bulk); - inlineAllAt(reference, /*pos=*/1); - // TODO: fix inlineMost(); + inlineMost(); // Circular Buffer with TMA loads tv2->circularBuffer( From 0cc170603fc3632899f721be6cbdff075c47e4a6 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:24:40 -0800 Subject: [PATCH 27/38] save --- tests/cpp/test_circular_buffering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index a350fde232d..790581cd224 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1452,9 +1452,9 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { MaxLogicalDomainInfoSpanningTree(reference).traverse(&propagator); // Parallelize - reference->axis(1)->parallelize(ParallelType::BIDx); + reference->axis(0)->parallelize(ParallelType::BIDx); reference->axis(2)->parallelize(ParallelType::TIDx); - tv2->axis(1)->parallelize(ParallelType::BIDx); + tv2->axis(0)->parallelize(ParallelType::BIDx); tv2->axis(2)->parallelize(ParallelType::Bulk); inlineMost(); From 70075e957bffd894661fc7290a5268784ae6b973 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:32:02 -0800 Subject: [PATCH 28/38] save --- tests/cpp/test_circular_buffering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 790581cd224..8cc255e9ec9 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1472,7 +1472,7 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { std::vector cg_outputs = ke.run({t0}); 
compare(tensor_outer_dim, cg_outputs.front(), t1); - testValidate(fusion.get(), cg_outputs, {t0}, {t1}, __LINE__, __FILE__); + EXPECT_EQ(at::allclose(cg_outputs.front(), t1, 1e-3), true); } TEST_P(TmaCircularBufferingTest, Persistent) { From 32ea65e0c648e9b7f00b726b642941103118d45f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:37:27 -0800 Subject: [PATCH 29/38] save --- tests/cpp/test_circular_buffering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 8cc255e9ec9..10a666dbdd4 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1472,7 +1472,7 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { std::vector cg_outputs = ke.run({t0}); compare(tensor_outer_dim, cg_outputs.front(), t1); - EXPECT_EQ(at::allclose(cg_outputs.front(), t1, 1e-3), true); + EXPECT_EQ(at::allclose(cg_outputs.front(), t1, 1e-3, 1e-3), true); } TEST_P(TmaCircularBufferingTest, Persistent) { From b34cf0559b81bba04152b62f25b01c0a27c94794 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:39:39 -0800 Subject: [PATCH 30/38] save --- tests/cpp/test_circular_buffering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 10a666dbdd4..6ad5bf71871 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1471,7 +1471,7 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { ke.compile(fusion.get(), {t0}); std::vector cg_outputs = ke.run({t0}); - compare(tensor_outer_dim, cg_outputs.front(), t1); + // compare(tensor_outer_dim, cg_outputs.front(), t1); EXPECT_EQ(at::allclose(cg_outputs.front(), t1, 1e-3, 1e-3), true); } From 807122b29615a0ce7f9a9ee2620e5e49600a7b2f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Mon, 2 Dec 2024 16:42:41 -0800 Subject: [PATCH 31/38] save --- tests/cpp/test_circular_buffering.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 6ad5bf71871..b149613e50b 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1478,6 +1478,13 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { TEST_P(TmaCircularBufferingTest, Persistent) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); + if (std::holds_alternative(circular_buffer_type)) { + GTEST_SKIP() + << "This test uses block reduce and block broadcast, " + << "which implies block sync, " + << "which can cause deadlock when combined with warp specialization."; + } + constexpr at::ScalarType dtype = at::ScalarType::Float; constexpr int64_t correction = 0; constexpr int64_t reduction_axis = 1; From 9f13ce4dd3c31adf2b654e91701fead846a7f83c Mon Sep 17 00:00:00 2001 From: Xiang Date: Tue, 3 Dec 2024 13:32:38 -0800 Subject: [PATCH 32/38] save --- csrc/device_lower/pass/predicate.cpp | 3 +++ csrc/ir/interface_nodes.h | 17 +++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index 97ed99f86c4..adeb0e0888a 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -256,6 +256,9 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { Val* first_warp = IrBuilder::ltExpr( NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); for (auto pt : {ParallelType::TIDy, 
ParallelType::TIDz}) { + // If we are in a load warp for the warp specialization loop that has + // specialization on `pt`, then pt is already predicated by the + // warp-dispatch if-then-else, we should not predicate it again here. bool in_load_warp_for_pt = load_warp_loop_it != for_loops_.end() && std::get( GpuLower::current() diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 67e112fd643..6a115ebea35 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -113,7 +113,7 @@ class TVDomainGuard; // // /load 0;\ \. -// / load 1; [prefetch = 3] | [prologue] +// / load 1; [prefetch = 3] | [prefetching] // [stage] load 2;/ /' // [ = 6 ] load 3; wait load 0; compute 0; \. // \ load 4; wait load 1; compute 1; | @@ -123,7 +123,7 @@ class TVDomainGuard; // load 2; wait load 5; compute 5; wait compute 3; | // load 3; wait load 0; compute 0; wait compute 4; | // load 4; wait load 1; compute 1; wait compute 5; | [main] -// load 5; wait load 2; compute 2; wait compute 0; | [loop] +// load 5; wait load 2; compute 2; wait compute 0; | // .................................................. | // .................................................. | // .................................................. | @@ -132,7 +132,7 @@ class TVDomainGuard; // load ; wait load ; compute ; wait compute ; | // load ; wait load ; compute ; /' // /wait load ; compute ; \. -// [same number as prefetch] wait load ; compute ; | [epilogue] +// [same number as prefetch] wait load ; compute ; | [draining] // \wait load ; compute ; wait all computes; /' // clang-format on @@ -142,7 +142,13 @@ class TVDomainGuard; // load pipeline depth = prefetch + 1 // compute pipeline depth = stage - prefetch // -// The above timeline can be implemented as the following loop structure: +// There are two ways to implement the above timeline: pipelined, and +// warp-specialization. +// +// In the pipelined way, the prefetching stage is implemented as a prologue +// loop, and main stage is implemented as a main loop, and the draining stage is +// implemented as an epilogue loop. That is, we will have the following loop +// structure: // // Prologue loop: // for i in range(prefetch): @@ -166,6 +172,9 @@ class TVDomainGuard; // stage - prefetch - 1 iterations and last iteration of the main loop is // redundant. We can remove them to further optimize the performance, but // we decide to keep them for simplicity. +// +// In the warp-specialized approach, we will use different warp/warp-group +// that struct Pipelined { bool uses_mbarrier_for_war = false; From 258c31d57c3d62be75a43f3b2dda299335be1687 Mon Sep 17 00:00:00 2001 From: Xiang Date: Tue, 3 Dec 2024 14:06:19 -0800 Subject: [PATCH 33/38] save --- csrc/ir/interface_nodes.h | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 6a115ebea35..0ca24556789 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -154,13 +154,25 @@ class TVDomainGuard; // for i in range(prefetch): // load data[i] to buffer[i] // -// Main loop: +// Main loop (using syncthreads to avoid WAR harzard): // for i in range(data.size - prefetch): // load data[i + prefetch] to buffer[(i + prefetch) % stage] // wait buffer[i % stage] to be ready // compute buffer[i % stage] // wait until the first compute in the queue is done // (i.e. 
stage - prefetch - 1 in flight computes remaining) +// __syncthreads(); +// +// Main loop (using mbarrier to avoid WAR harzard): +// for i in range(data.size - prefetch): +// wait buffer[(i + prefetch) % stage] to be empty +// load data[i + prefetch] to buffer[(i + prefetch) % stage] +// wait buffer[i % stage] to be ready +// compute buffer[i % stage] +// wait until the first compute in the queue is done +// (i.e. stage - prefetch - 1 in flight computes remaining) +// signal that buffer (i + prefetch + 1) % stage is empty and ready to be +// loaded again // // Epilogue loop: // for i in range(data.size - prefetch, data.size): @@ -174,7 +186,8 @@ class TVDomainGuard; // we decide to keep them for simplicity. // // In the warp-specialized approach, we will use different warp/warp-group -// that +// for loading and consuming circular buffer. We will generate code like: + struct Pipelined { bool uses_mbarrier_for_war = false; From e2dfda7b8fbb6d3d76069e11874b5a274ec9f11d Mon Sep 17 00:00:00 2001 From: Xiang Date: Tue, 3 Dec 2024 15:29:24 -0800 Subject: [PATCH 34/38] comment --- csrc/ir/interface_nodes.h | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index 0ca24556789..4e53a6207cc 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -157,7 +157,7 @@ class TVDomainGuard; // Main loop (using syncthreads to avoid WAR harzard): // for i in range(data.size - prefetch): // load data[i + prefetch] to buffer[(i + prefetch) % stage] -// wait buffer[i % stage] to be ready +// wait buffer[i % stage] to be loaded // compute buffer[i % stage] // wait until the first compute in the queue is done // (i.e. stage - prefetch - 1 in flight computes remaining) @@ -167,7 +167,7 @@ class TVDomainGuard; // for i in range(data.size - prefetch): // wait buffer[(i + prefetch) % stage] to be empty // load data[i + prefetch] to buffer[(i + prefetch) % stage] -// wait buffer[i % stage] to be ready +// wait buffer[i % stage] to be loaded // compute buffer[i % stage] // wait until the first compute in the queue is done // (i.e. stage - prefetch - 1 in flight computes remaining) @@ -186,8 +186,28 @@ class TVDomainGuard; // we decide to keep them for simplicity. // // In the warp-specialized approach, we will use different warp/warp-group -// for loading and consuming circular buffer. We will generate code like: - +// for loading and computing. We will generate code like below (assuming warp +// specialized on TIDy): +// +// if (threadIdx.y == blockDim.y - 1) { +// // If we use warp specialization on TIDy, then the blockDim.y of the +// // kernel will be (whatever_value_inferred_from_schedule + 1), and the +// // last threadIdx.y will be used as load warp +// for i in range(data.size): +// wait buffer[i % stage] to be empty +// load data[i] to buffer[i % stage] +// } else { +// // Every threadIdx.y other than the last will be used for compute +// for i in range(prefetch + 1): +// signal that buffer i % stage is empty and ready to load +// for i in range(data.size): +// wait buffer[i % stage] to be loaded +// compute buffer[i % stage] +// wait until the first compute in the queue is done +// (i.e. 
stage - prefetch - 1 in flight computes remaining) +// signal that buffer (i + prefetch + 1) % stage is empty and ready to be +// loaded again +// } struct Pipelined { bool uses_mbarrier_for_war = false; @@ -206,8 +226,6 @@ inline std::ostream& operator<<(std::ostream& os, const Pipelined& pipelined) { return os << "Pipelined"; } -// For example, if `on` is TIDy, then will assign additional TIDy for cirular -// buffer loading. struct WarpSpecialized { ParallelType on; explicit WarpSpecialized(ParallelType on) : on(on) {} From 88f2e3ffc02d570be6411eda837fd499d79fa828 Mon Sep 17 00:00:00 2001 From: Xiang Date: Tue, 3 Dec 2024 15:41:24 -0800 Subject: [PATCH 35/38] doc --- csrc/device_lower/pass/allocation.cpp | 4 +++- csrc/parallel_dimension_map.cpp | 2 +- csrc/parallel_dimension_map.h | 19 ++++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 90f80a14f2c..bee61c46873 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -89,7 +89,9 @@ Expr* initializeMbarrier( // threads in the CTA. num_of_arrives = SimplifyingIrBuilder::maybeCastExpr( DataType::UInt32, - GpuLower::current()->parallelDimensionMap().getNumThreadsEachBlockIgnoringWarpSpecialization()); + GpuLower::current() + ->parallelDimensionMap() + .getNumComputeThreadsEachBlock()); } // Initialize mbarrier for each circular buffer stage. Use the thread diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 2792bd75e7d..1878fdda59b 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -182,7 +182,7 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const { return exact_types_.find(pt) != exact_types_.end(); } -Val* ParallelDimensionMap::getNumThreadsEachBlockIgnoringWarpSpecialization() const { +Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const { Val* num_threads = FusionGuard::getCurFusion()->oneVal(); for (auto pt : kParallelTypeTIDs) { auto dim = getRaw(pt); diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 5fce9050eed..0e25182c12e 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -41,7 +41,19 @@ class ParallelDimensionMap { return dim_map_; } - Val* getNumThreadsEachBlockIgnoringWarpSpecialization() const; + //! Get the number of threads per each CTA used for computation. When there is + //! no warp specialization, the result is trivial: it is just the product of + //! parallel dimensions of TIDx, TIDy and TIDz. If we do have warp + //! specialization, this returns the number of threads used for computing. For + //! example, if we have a simple kernel warp specialized on TIDy and all the + //! TIDx parallelized IterDomains have extent 32, and all the TIDy + //! parallelized IterDomains have extent 16, and there is no TIDz + //! parallelization, then we will have: + //! blockDim = (x=32, y=17, z=1) + //! And this function will return (32 * 16) because the extra one for TIDy is + //! introduced by warp specialization and only used for loading circular + //! buffer tensors. + Val* getNumComputeThreadsEachBlock() const; bool has(ParallelType pt) const { return dim_map_.count(pt) > 0; @@ -52,6 +64,9 @@ class ParallelDimensionMap { //! multiple of the warp size. void adjustMappingsForWarpPadding(); + //! If we are doing warp specialization on pt, then we need to increase + //! 
the parallel dimension size of pt by one, where the extra one is used + //! as the load warp. In this case, pt becomes non-exact. void setWarpSpecializeOn(ParallelType pt); private: @@ -61,6 +76,8 @@ class ParallelDimensionMap { //! Set of parallel types whose dimensions are identified to be //! exactly the same as extents of mapped domains. std::unordered_set exact_types_; + + //! Set of parallel types that we are doing warp specialization on std::unordered_set warp_specialized_types_; }; From 07e59defdd76f432584d3f76b67dc1c2b1f8c2c0 Mon Sep 17 00:00:00 2001 From: Xiang Date: Tue, 3 Dec 2024 16:43:13 -0800 Subject: [PATCH 36/38] error message --- tests/cpp/test_circular_buffering.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 6c3515814c0..f8cd27d1cfa 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1368,7 +1368,7 @@ TEST_P(TmaCircularBufferingTest, InnerReduction) { if (std::holds_alternative(circular_buffer_type)) { GTEST_SKIP() - << "This test uses block reduce, which implies block sync, " + << "This test uses block reduce, which uses hard-coded blockDim, " << "which can cause deadlock when combined with warp specialization."; } @@ -1489,7 +1489,7 @@ TEST_P(TmaCircularBufferingTest, Persistent) { if (std::holds_alternative(circular_buffer_type)) { GTEST_SKIP() << "This test uses block reduce and block broadcast, " - << "which implies block sync, " + << "which has hard-coded blockDim, " << "which can cause deadlock when combined with warp specialization."; } From 8eb0dc0754c5557853fe91bc9462eff375af7b79 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 4 Dec 2024 10:36:49 -0800 Subject: [PATCH 37/38] Warp specialization on x (#3525) --- csrc/device_lower/pass/predicate.cpp | 41 ++++++++++++++++----------- tests/cpp/test_circular_buffering.cpp | 5 +++- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index adeb0e0888a..afff481fb94 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -251,29 +251,36 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { return fl->circularBufferLoopStage() == CircularBufferLoopStage::LoadWarp; }); + ParallelType load_warp_on = ParallelType::Serial; + if (load_warp_loop_it != for_loops_.end()) { + load_warp_on = std::get( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor( + (*load_warp_loop_it)->iter_domain()) + .type) + .on; + } + // If we are in a load warp, then the warp-dispatching IfThenElse + // already selects on `load_warp_on`, so we should not generate + // predicates for it here. const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); - Val* first_warp = IrBuilder::ltExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); + Val* conditional = load_warp_on == ParallelType::TIDx + ? pred->fusion()->trueVal() + : SimplifyingIrBuilder::logicalAndExpr( + elect_sync_val, + IrBuilder::ltExpr( + NamedScalar::getParallelIndex(ParallelType::TIDx), + warp_size)); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { - // If we are in a load warp for the warp specialization loop that has - // specialization on `pt`, then pt is already predicated by the - // warp-dispatch if-then-else, we should not predicate it again here. 
- bool in_load_warp_for_pt = load_warp_loop_it != for_loops_.end() && - std::get( - GpuLower::current() - ->circularBufferInfo() - .getCircularBufferOptionsFor( - (*load_warp_loop_it)->iter_domain()) - .type) - .on == pt; - if (pdim_map.has(pt) && !in_load_warp_for_pt) { - first_warp = SimplifyingIrBuilder::logicalAndExpr( - first_warp, + if (pdim_map.has(pt) && load_warp_on != pt) { + conditional = SimplifyingIrBuilder::logicalAndExpr( + conditional, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); } } - return SimplifyingIrBuilder::logicalAndExpr(first_warp, elect_sync_val); + return conditional; } default: break; diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index f8cd27d1cfa..80776155860 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1865,7 +1865,10 @@ auto tmaCircularBufferingParams() { // https://en.wikipedia.org/wiki/Lehmer_random_number_generator uint32_t lcg_parkmiller = 1; const std::vector all_types{ - Pipelined(false), Pipelined(true), WarpSpecialized(ParallelType::TIDy)}; + Pipelined(false), + Pipelined(true), + WarpSpecialized(ParallelType::TIDx), + WarpSpecialized(ParallelType::TIDy)}; std::vector values; for (int64_t i : {2, 4}) { for (int64_t j : c10::irange(-i, i)) { From e33f8e24969221942ff1f20233bcfa7b564d285c Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Wed, 4 Dec 2024 10:43:04 -0800 Subject: [PATCH 38/38] Update tests/cpp/test_circular_buffering.cpp Co-authored-by: Ryan Spring --- tests/cpp/test_circular_buffering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 80776155860..725f2c128b9 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1267,7 +1267,7 @@ TEST_P(TmaCircularBufferingTest, Pointwise) { number_of_stages, prefetch_distance, circular_buffer_type); // Circular Buffer with set operation. - // Note that in order to use warp specialization, all circilar buffers must be + // Note that in order to use warp specialization, all circular buffers must be // loaded by TMA, so for this test we disable circular buffering of set op if // we are testing warp specialization. if (!std::holds_alternative(circular_buffer_type)) {