diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp
index 2e00308f233..4b4d962eace 100644
--- a/csrc/device_lower/pass/predicate.cpp
+++ b/csrc/device_lower/pass/predicate.cpp
@@ -42,10 +42,54 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
 
   using kir::ExprMutator::handle;
 
+  // The ElectSync predicate expects a single thread to run operations within
+  // If-Then-Else. Any TensorView with thread parallelization is incompatible
+  // with this If-Then-Else because it can create a conflicting predicate.
+  void checkElectSyncCompatibility(Expr* expr) {
+    NVF_CHECK(expr->predicate()->predicate_type() == PredicateType::ElectSync);
+    NVF_ERROR(expr->isA<kir::IfThenElse>());
+
+    // Check all the expressions in the scope
+    auto check_scope_compatibility = [](Scope& scope) {
+      for (Expr* expr : scope.exprs()) {
+        // Thread predicates are generated based on the expression's outputs
+        for (Val* val : expr->outputs()) {
+          // short-circuit
+          if (!val->isA<kir::TensorIndex>()) {
+            continue;
+          }
+          // Check that none of the IterDomains in TensorView are parallelized
+          // with a thread dimension like TIDx, TIDy, or TIDz.
+          TensorView* tv = val->as<kir::TensorIndex>()->view();
+          bool is_thread_parallelized = std::any_of(
+              tv->domain()->loop().begin(),
+              tv->domain()->loop().end(),
+              [](IterDomain* id) { return id->isThreadDim(); });
+          NVF_ERROR(
+              !is_thread_parallelized,
+              "This thread-parallelized TensorView ",
+              tv->toString(),
+              " is incorrectly contained within a If-Then-Else with the ",
+              "ElectSync predicate.");
+        }
+      }
+    };
+
+    // Check the thenBody and elseBody of If-Then-Else
+    kir::IfThenElse* ite = expr->as<kir::IfThenElse>();
+    check_scope_compatibility(ite->thenBody());
+    check_scope_compatibility(ite->elseBody());
+  }
+
   void dispatch(Expr* expr) final {
     if (expr != nullptr && expr->predicate() != nullptr) {
       // Replace expr predicate with bool conditional
       auto conditional = generateConditional(expr->predicate());
+
+      if (expr->predicate()->predicate_type() == PredicateType::ElectSync) {
+        checkElectSyncCompatibility(expr);
+      }
+
       if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
         if (expr->isA<ForLoop>()) {
           // TODO: This logic doesn't seem to fit well here, for unswitch the
@@ -209,6 +253,27 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
       // here.
       return IrBuilder::create<Val>(true, DataType::Bool);
     }
+    case PredicateType::ElectSync: {
+      Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt);
+      Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt);
+      Val* full_mask_val =
+          IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);
+
+      Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool);
+      IrBuilder::create<UnaryOp>(
+          UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
+
+      Val* first_warp = IrBuilder::logicalAndExpr(
+          IrBuilder::logicalAndExpr(
+              IrBuilder::ltExpr(
+                  NamedScalar::getParallelIndex(ParallelType::TIDx),
+                  warp_size),
+              IrBuilder::eqExpr(
+                  NamedScalar::getParallelIndex(ParallelType::TIDy), zero)),
+          IrBuilder::eqExpr(
+              NamedScalar::getParallelIndex(ParallelType::TIDz), zero));
+      return IrBuilder::logicalAndExpr(first_warp, elect_sync_val);
+    }
     default:
       break;
   }
diff --git a/csrc/fusion_executor/executor.cpp b/csrc/fusion_executor/executor.cpp
index 00c77005e0e..148a8aeae7a 100644
--- a/csrc/fusion_executor/executor.cpp
+++ b/csrc/fusion_executor/executor.cpp
@@ -828,6 +828,16 @@ void FusionExecutor::initializeExecutorEntry(
 
   executor_utils::validateCircularBuffering(kernel(), expr_eval);
 
+  // Check that a full warp exists in blockDim.x if the kernel contains
+  // ElectSync predicate.
+  constexpr int64_t warp_size = 32;
+  NVF_ERROR(
+      !kernel()->summary().has_elect_sync_predicate ||
+          launch_params.bdimx() >= warp_size,
+      "This cuda kernel contains electSync predicate. "
+      "Expected blockDim.x >= 32 but found ",
+      launch_params.bdimx());
+
   std::vector<GlobalBufferInfo> output_info;
 
   if (outputs.empty()) {
diff --git a/csrc/ir/base_nodes.cpp b/csrc/ir/base_nodes.cpp
index 43c2b28b01c..96f5b6762a6 100644
--- a/csrc/ir/base_nodes.cpp
+++ b/csrc/ir/base_nodes.cpp
@@ -196,6 +196,14 @@ bool Val::isConstScalar() const {
   if (!isScalar()) {
     return false;
   }
+  // elect.sync ptx picks a leader thread from membermask.
+  // It cannot be evaluated at compile-time.
+  if (Expr* def = definition()) {
+    if (def->isA<UnaryOp>() &&
+        def->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::ElectSync) {
+      return false;
+    }
+  }
   return ir_utils::dependenciesSatisfied(this);
 }
 
diff --git a/csrc/kernel.cpp b/csrc/kernel.cpp
index 23c332551dc..caf23e9df6a 100644
--- a/csrc/kernel.cpp
+++ b/csrc/kernel.cpp
@@ -221,6 +221,15 @@ class KernelIrScanner : private IrVisitor {
         summary_.has_grid_broadcasts || parallel_types.hasBID();
   }
 
+  void handle(IfThenElse* ite) final {
+    // Do we have any elect sync predicates?
+    if (ite->predicate()->predicate_type() == PredicateType::ElectSync) {
+      summary_.has_elect_sync_predicate = true;
+    }
+    // Run default handle
+    IrVisitor::handle(ite);
+  }
+
  private:
   size_t max_smem_type_size_ = 0;
   KernelSummary summary_;
diff --git a/csrc/kernel.h b/csrc/kernel.h
index 3aff8c7749b..491aea30569 100644
--- a/csrc/kernel.h
+++ b/csrc/kernel.h
@@ -125,6 +125,11 @@ struct KernelSummary {
 
   //! Track Circular Buffer TensorViews
   CircularBufferInfo circular_buffer_info;
+
+  //! Track if there are ElectSync predicates in this Kernel.
+  //! Reason: At runtime, we check that at least a single warp along TIDx axis
+  //! exists.
+  bool has_elect_sync_predicate = false;
 };
 
 class KernelPerformanceProfile {
diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h
index 1583d1321b9..076b95f5591 100644
--- a/csrc/kernel_ir.h
+++ b/csrc/kernel_ir.h
@@ -83,7 +83,8 @@ class Predicate final : public Val {
   const Expr* expr() const {
     NVF_ERROR(
         ptype_ != PredicateType::Unswitch &&
-        ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual);
+        ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual &&
+        ptype_ != PredicateType::ElectSync);
     return expr_;
   }
 
@@ -91,7 +92,8 @@ class Predicate final : public Val {
     NVF_ERROR(
         ptype_ == PredicateType::Inline ||
         ptype_ == PredicateType::Misaligned ||
-        ptype_ == PredicateType::ReductionWrite);
+        ptype_ == PredicateType::ReductionWrite ||
+        ptype_ == PredicateType::ElectSync);
     return thread_pred_;
   }
 
diff --git a/csrc/type.cpp b/csrc/type.cpp
index 2169f97ef13..2c5e75f6a6a 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -468,6 +468,8 @@ static const char* unary_op_type2string(UnaryOpType t) {
       return "std::imag";
     case UnaryOpType::ToUnsignedSmemAddr:
       return "toSmem";
+    case UnaryOpType::ElectSync:
+      return "Hopper::electSync";
     case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8:
       return "Turing::adjustPartialLdMatrixAddrInTuring<8>";
     case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16:
diff --git a/csrc/type.h b/csrc/type.h
index 142158e4782..b4ff32c8607 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -46,6 +46,7 @@ enum class ValType {
 // Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
 // ReductionWrite - Same as Inline but without reduction axes
 // LoopRotation - Predicate added by loop rotation, currently always true.
+// ElectSync - Select a single thread to launch asynchronous operations.
 enum class PredicateType {
   Manual,
   Inline,
@@ -53,7 +54,8 @@ enum class PredicateType {
   Vectorize,
   Misaligned,
   ReductionWrite,
-  LoopRotation
+  LoopRotation,
+  ElectSync
 };
 
 // Index type is a convenience type that may be a 64 or 32 signed integer.
@@ -589,6 +591,7 @@ enum class UnaryOpType {
   IsReal,
 
   // Special unary ops
+  ElectSync,
   ToUnsignedSmemAddr,
   AdjustPartialLdMatrixAddrInTuring8,
   AdjustPartialLdMatrixAddrInTuring16
diff --git a/doc/dev/tma.md b/doc/dev/tma.md
index 373e4a008f6..b4fc12b31c7 100644
--- a/doc/dev/tma.md
+++ b/doc/dev/tma.md
@@ -349,6 +349,11 @@ the TMA domain can be completely inferred from the schedule.
 > We do not have validation on shared memory schedule yet.
 > If you scheduled something invalid, likely you will see misaligned address error or silent wrong result.
 
+> [!WARNING]
+> When using circular buffering with TMA, a single thread is selected to launch the TMA load and mbarrier operations.
+> In this case, we cannot apply any block parallelization to the consumer TensorView, which will create a thread predicate.
+> A compile-time error will occur if you apply circular buffering and block parallelization together.
+
 #### Data swizzle
 
 So far we have been ignoring the shared memory swizzle feature of TMA and
diff --git a/runtime/memory.cu b/runtime/memory.cu
index 969420fe95c..3a66c26b318 100644
--- a/runtime/memory.cu
+++ b/runtime/memory.cu
@@ -69,7 +69,7 @@ namespace Hopper {
 //
 // Document Reference:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-elect-sync
-__device__ inline bool elect_sync(const uint32_t& membermask) {
+__device__ inline bool electSync(const uint32_t& membermask) {
   uint32_t is_elected;
   asm volatile(
       "{\n\t .reg .pred P_OUT; \n\t"