Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create ElectSync predicate type #2923

Merged
merged 7 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions csrc/device_lower/pass/predicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,54 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {

using kir::ExprMutator::handle;

// The ElectSync predicate expects a single thread to run operations within
// If-Then-Else. Any TensorView with thread parallelization is incompatible
// with this If-Then-Else because it can create a conflicting predicate.
void checkElectSyncCompatibility(Expr* expr) {
  NVF_CHECK(expr->predicate()->predicate_type() == PredicateType::ElectSync);
  NVF_ERROR(expr->isA<kir::IfThenElse>());

  // Walk every expression directly inside a scope and verify that none of
  // its TensorIndex outputs belongs to a thread-parallelized TensorView.
  auto check_scope_compatibility = [](Scope& scope) {
    for (Expr* inner_expr : scope.exprs()) {
      // Thread predicates are generated based on the expression's outputs
      for (Val* output : inner_expr->outputs()) {
        // Only TensorIndex outputs can carry thread parallelization.
        if (!output->isA<kir::TensorIndex>()) {
          continue;
        }
        // Check that none of the IterDomains in TensorView are parallelized
        // with a thread dimension like TIDx, TIDy, or TIDz.
        TensorView* tv = output->as<kir::TensorIndex>()->view();
        bool is_thread_parallelized = false;
        for (IterDomain* id : tv->domain()->loop()) {
          if (id->isThreadDim()) {
            is_thread_parallelized = true;
            break;
          }
        }
        NVF_ERROR(
            !is_thread_parallelized,
            "This thread-parallelized TensorView ",
            tv->toString(),
            " is incorrectly contained within a If-Then-Else with the ",
            "ElectSync predicate.");
      }
    }
  };

  // Inspect both branches of the If-Then-Else.
  kir::IfThenElse* ite = expr->as<kir::IfThenElse>();
  check_scope_compatibility(ite->thenBody());
  check_scope_compatibility(ite->elseBody());
}

void dispatch(Expr* expr) final {
if (expr != nullptr && expr->predicate() != nullptr) {
// Replace expr predicate with bool conditional
auto conditional = generateConditional(expr->predicate());

if (expr->predicate()->predicate_type() == PredicateType::ElectSync) {
checkElectSyncCompatibility(expr);
}

if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
if (expr->isA<kir::IfThenElse>()) {
// TODO: This logic doesn't seem to fit well here, for unswitch the
Expand Down Expand Up @@ -209,6 +253,27 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
// here.
return IrBuilder::create<Val>(true, DataType::Bool);
}
case PredicateType::ElectSync: {
Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt);
Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt);
Val* full_mask_val =
IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);

Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool);
IrBuilder::create<UnaryOp>(
UnaryOpType::ElectSync, elect_sync_val, full_mask_val);

Val* first_warp = IrBuilder::logicalAndExpr(
IrBuilder::logicalAndExpr(
IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx),
warp_size),
IrBuilder::eqExpr(
NamedScalar::getParallelIndex(ParallelType::TIDy), zero)),
IrBuilder::eqExpr(
NamedScalar::getParallelIndex(ParallelType::TIDz), zero));
return IrBuilder::logicalAndExpr(first_warp, elect_sync_val);
}
default:
break;
}
Expand Down
10 changes: 10 additions & 0 deletions csrc/fusion_executor/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,16 @@ void FusionExecutor::initializeExecutorEntry(

executor_utils::validateCircularBuffering(kernel(), expr_eval);

// Check that a full warp exists in blockDim.x if the kernel contains
// ElectSync predicate.
constexpr int64_t warp_size = 32;
NVF_ERROR(
!kernel()->summary().has_elect_sync_predicate ||
launch_params.bdimx() >= warp_size,
"This cuda kernel contains electSync predicate. "
"Expected blockDim.x >= 32 but found ",
launch_params.bdimx());

std::vector<GlobalBufferInfo> output_info;

if (outputs.empty()) {
Expand Down
8 changes: 8 additions & 0 deletions csrc/ir/base_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ bool Val::isConstScalar() const {
if (!isScalar()) {
return false;
}
// elect.sync ptx picks a leader thread from membermask.
// It cannot be evaluated at compile-time.
if (Expr* def = definition()) {
if (def->isA<UnaryOp>() &&
def->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::ElectSync) {
return false;
}
}
return ir_utils::dependenciesSatisfied(this);
}

Expand Down
9 changes: 9 additions & 0 deletions csrc/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,15 @@ class KernelIrScanner : private IrVisitor {
summary_.has_grid_broadcasts || parallel_types.hasBID();
}

void handle(IfThenElse* ite) final {
  // Record whether this kernel contains any ElectSync predicate; the
  // summary flag is sticky once set.
  summary_.has_elect_sync_predicate = summary_.has_elect_sync_predicate ||
      (ite->predicate()->predicate_type() == PredicateType::ElectSync);
  // Continue with the default traversal of the nested scopes.
  IrVisitor::handle(ite);
}

private:
size_t max_smem_type_size_ = 0;
KernelSummary summary_;
Expand Down
5 changes: 5 additions & 0 deletions csrc/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ struct KernelSummary {

//! Track Circular Buffer TensorViews
CircularBufferInfo circular_buffer_info;

//! Track if there are ElectSync predicates in this Kernel.
//! Reason: At runtime, we check that at least a single warp along TIDx axis
//! exists.
bool has_elect_sync_predicate = false;
};

class KernelPerformanceProfile {
Expand Down
6 changes: 4 additions & 2 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ class Predicate final : public Val {
//! Returns the expression this predicate was created for.
//! Not available for predicate types that are not tied to a single
//! expression (Unswitch, Vectorize, Manual, ElectSync).
const Expr* expr() const {
  NVF_ERROR(
      ptype_ != PredicateType::Unswitch &&
      ptype_ != PredicateType::Vectorize && ptype_ != PredicateType::Manual &&
      ptype_ != PredicateType::ElectSync);
  return expr_;
}

//! Returns the thread predicate associated with this predicate.
//! Only available for the predicate types listed in the assertion below;
//! ElectSync was added to this set by this change.
Val* thread_pred() const {
  NVF_ERROR(
      ptype_ == PredicateType::Inline ||
      ptype_ == PredicateType::Misaligned ||
      ptype_ == PredicateType::ReductionWrite ||
      ptype_ == PredicateType::ElectSync);
  return thread_pred_;
}

Expand Down
2 changes: 2 additions & 0 deletions csrc/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ static const char* unary_op_type2string(UnaryOpType t) {
return "std::imag";
case UnaryOpType::ToUnsignedSmemAddr:
return "toSmem";
case UnaryOpType::ElectSync:
return "Hopper::electSync";
case UnaryOpType::AdjustPartialLdMatrixAddrInTuring8:
return "Turing::adjustPartialLdMatrixAddrInTuring<8>";
case UnaryOpType::AdjustPartialLdMatrixAddrInTuring16:
Expand Down
5 changes: 4 additions & 1 deletion csrc/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,16 @@ enum class ValType {
// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
// ReductionWrite - Same as Inline but without reduction axes
// LoopRotation - Predicate added by loop rotation, currently always true.
// ElectSync - Select a single thread to launch asynchronous operations.
// Kinds of predicates used to guard kernel expressions; see the
// per-type descriptions in the comment block above.
enum class PredicateType {
  Manual,
  Inline,
  Unswitch,
  Vectorize,
  Misaligned, // Same as Inline, with the Misaligned flag
  ReductionWrite, // Same as Inline but without reduction axes
  LoopRotation, // Added by loop rotation; currently always true
  ElectSync // Selects a single thread to launch asynchronous operations
};

// Index type is a convenience type that may be a 64 or 32 signed integer.
Expand Down Expand Up @@ -589,6 +591,7 @@ enum class UnaryOpType {
IsReal,

// Special unary ops
ElectSync,
ToUnsignedSmemAddr,
AdjustPartialLdMatrixAddrInTuring8,
AdjustPartialLdMatrixAddrInTuring16
Expand Down
5 changes: 5 additions & 0 deletions doc/dev/tma.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,11 @@ the TMA domain can be completely inferred from the schedule.
> We do not have validation on shared memory schedule yet.
> If you scheduled something invalid, likely you will see misaligned address error or silent wrong result.

> [!WARNING]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps there should be a circular buffering section but I added the warning here for now.

> When using circular buffering with TMA, a single thread is selected to launch the TMA load and mbarrier operations.
> In this case, we cannot apply any block parallelization to the consumer TensorView, because such parallelization would create a conflicting thread predicate.
> A compile-time error will occur if you apply circular buffering and block parallelization together.

#### Data swizzle

So far we have been ignoring the shared memory swizzle feature of TMA and
Expand Down
2 changes: 1 addition & 1 deletion runtime/memory.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ namespace Hopper {
//
// Document Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-elect-sync
__device__ inline bool elect_sync(const uint32_t& membermask) {
__device__ inline bool electSync(const uint32_t& membermask) {
uint32_t is_elected;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
Expand Down
Loading