Skip to content

Commit

Permalink
Merge branch 'main' into resize_scheduler_initial_version
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Dec 13, 2024
2 parents 52acb42 + 201a636 commit 6363298
Show file tree
Hide file tree
Showing 15 changed files with 619 additions and 106 deletions.
25 changes: 24 additions & 1 deletion csrc/compute_at_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,9 @@ void IterDomainGraph::build(Fusion* fusion) {

// Grab all the logical ids.
for (auto consumer_tv : all_consumer_tvs) {
auto exprs = StmtSort::getExprsTo(
auto exprs = StmtSort::getExprsBetween(
{consumer_tv->getMaybeRootDomain().begin(),
consumer_tv->getMaybeRootDomain().end()},
{consumer_tv->getLogicalDomain().begin(),
consumer_tv->getLogicalDomain().end()});
for (auto expr : exprs) {
Expand Down Expand Up @@ -663,6 +665,20 @@ void IterDomainGraph::build(Fusion* fusion) {
continue;
}

// logical_id_uses are guaranteed to be valid exprs, but
// first_logical_id->definition() may not be part of the valid
// exprs
if (!prop_forward) {
if (std::any_of(
first_expr->inputs().begin(),
first_expr->inputs().end(),
[&](Val* id_input) {
return !all_ids_.has(id_input->as<IterDomain>());
})) {
continue;
}
}

if (visited_exprs.find(first_expr) != visited_exprs.end()) {
continue;
}
Expand Down Expand Up @@ -1282,6 +1298,13 @@ void ComputeAtMap::buildUniqueExactExprMaps() {
if (id->definition() != nullptr) {
auto id_inputs =
ir_utils::filterByType<IterDomain>(id->definition()->inputs());
// If any input ID is not included in the map, this definition
// should not be included either.
if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) {
return !idExistsInMap(id_input);
})) {
continue;
}
if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) {
return disjoint_set_shared_ptr->has(id_input);
})) {
Expand Down
6 changes: 5 additions & 1 deletion csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,11 @@ std::array<UnitDim, 2> getMmaLayout(const MmaOp* expr) {

auto out_tv = ir_utils::getTv(expr->out());
IterDomain* reduction_id = nullptr;
for (auto id : out_tv->getLogicalDomain()) {
// For Hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. In the split-k case, use the root domain for the
// mma layout because the k dimension is divided into two IterDomains in the
// logical domain.
for (auto id : out_tv->getMaybeRootDomain()) {
if (id->isReduction()) {
reduction_id = id;
break;
Expand Down
36 changes: 18 additions & 18 deletions csrc/kernel_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class Predicate final : public Val {

std::string toString(int indent_size = 0) const override;

NVF_API std::string toInlineString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

PredicateType predicate_type() const {
return ptype_;
Expand Down Expand Up @@ -148,7 +148,7 @@ class Predicate final : public Val {
Val* value_ = nullptr;
};

class NVF_API TensorIndex final : public Val {
class TensorIndex final : public Val {
public:
TensorIndex(
IrBuilderPasskey,
Expand Down Expand Up @@ -252,7 +252,7 @@ class Asm final : public Expr {
//! is required as an intermediate within a kernel. The extent is the expression
//! of the size of the buffer that is generated from the TensorView that
//! describes the output of an operation.
class NVF_API Allocate final : public Expr {
class Allocate final : public Expr {
public:
using Expr::Expr;

Expand Down Expand Up @@ -385,7 +385,7 @@ class NVF_API Allocate final : public Expr {
//
// TODO(kir): change name to SyncThreads as we could have other barriers.
//
class NVF_API BlockSync final : public Expr {
class BlockSync final : public Expr {
public:
using Expr::Expr;

Expand All @@ -408,7 +408,7 @@ class NVF_API BlockSync final : public Expr {

// Synchronize all blocks in the device; implies that a cooperative group
// launch is required.
class NVF_API GridSync final : public Expr {
class GridSync final : public Expr {
public:
using Expr::Expr;

Expand Down Expand Up @@ -436,7 +436,7 @@ class NVF_API GridSync final : public Expr {
};

// PTX: fence.proxy.async
class NVF_API FenceAsyncProxy final : public Expr {
class FenceAsyncProxy final : public Expr {
public:
using Expr::Expr;

Expand All @@ -453,7 +453,7 @@ class NVF_API FenceAsyncProxy final : public Expr {
};

// PTX: wgmma.fence.sync.aligned
class NVF_API WgMmaFence final : public Expr {
class WgMmaFence final : public Expr {
public:
using Expr::Expr;

Expand All @@ -469,7 +469,7 @@ class NVF_API WgMmaFence final : public Expr {
std::string toInlineString(int indent_size = 0) const override;
};

class NVF_API MBarrierInit final : public Expr {
class MBarrierInit final : public Expr {
public:
using Expr::Expr;
explicit MBarrierInit(
Expand All @@ -495,7 +495,7 @@ class NVF_API MBarrierInit final : public Expr {
}
};

class NVF_API MBarrierInvalidate final : public Expr {
class MBarrierInvalidate final : public Expr {
public:
using Expr::Expr;
explicit MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier);
Expand All @@ -514,7 +514,7 @@ class NVF_API MBarrierInvalidate final : public Expr {
}
};

class NVF_API MBarrierArrive final : public Expr {
class MBarrierArrive final : public Expr {
public:
using Expr::Expr;
explicit MBarrierArrive(IrBuilderPasskey passkey, Val* state, Val* mbarrier);
Expand Down Expand Up @@ -544,7 +544,7 @@ class NVF_API MBarrierArrive final : public Expr {
// This is usually used to specify the number of bytes that will be
// transferred for cp.async and cp.async.bulk, so that future mbarrier.wait
// can wait for the completion of the transfer.
class NVF_API MBarrierArriveExpectTx final : public Expr {
class MBarrierArriveExpectTx final : public Expr {
public:
using Expr::Expr;
explicit MBarrierArriveExpectTx(
Expand Down Expand Up @@ -578,7 +578,7 @@ class NVF_API MBarrierArriveExpectTx final : public Expr {
}
};

class NVF_API MBarrierWait final : public Expr {
class MBarrierWait final : public Expr {
public:
using Expr::Expr;
explicit MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state);
Expand All @@ -601,7 +601,7 @@ class NVF_API MBarrierWait final : public Expr {
}
};

class NVF_API MBarrierWaitParity final : public Expr {
class MBarrierWaitParity final : public Expr {
public:
using Expr::Expr;
explicit MBarrierWaitParity(
Expand Down Expand Up @@ -796,7 +796,7 @@ class UpdateMagicZero final : public Expr {
//!
//! TODO(kir): this is not a real expression
//!
class NVF_API IfThenElse final : public Expr {
class IfThenElse final : public Expr {
public:
using Expr::Expr;

Expand Down Expand Up @@ -915,7 +915,7 @@ class GridReduction final : public ReductionOp {
}
};

class NVF_API GroupedGridReduction final : public GroupedReductionOp {
class GroupedGridReduction final : public GroupedReductionOp {
public:
using GroupedReductionOp::GroupedReductionOp;

Expand Down Expand Up @@ -1006,7 +1006,7 @@ class NVF_API GroupedGridReduction final : public GroupedReductionOp {
//!
//! This node provides KernelExecutor the information it needs to allocate the
//! broadcast and sync buffers.
class NVF_API GridBroadcast final : public Expr {
class GridBroadcast final : public Expr {
public:
using Expr::Expr;

Expand Down Expand Up @@ -1117,7 +1117,7 @@ class GridWelford final : public Expr {
}
};

class NVF_API GroupedGridWelford final : public GroupedWelfordOp {
class GroupedGridWelford final : public GroupedWelfordOp {
public:
using GroupedWelfordOp::GroupedWelfordOp;

Expand Down Expand Up @@ -1211,7 +1211,7 @@ class NVF_API GroupedGridWelford final : public GroupedWelfordOp {

//! Represents a WelfordOp with the division by count hoisted out
//! of an innermost loop
class NVF_API VectorizedWelfordOp final : public WelfordOp {
class VectorizedWelfordOp final : public WelfordOp {
public:
using WelfordOp::WelfordOp;

Expand Down
10 changes: 8 additions & 2 deletions csrc/scheduler/ampere_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1302,11 +1302,17 @@ void AmpereMultipleMatmulScheduler::setUpCircularBuffering() {

for (TensorView* acw_smem : acw_smems_) {
acw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage);
params_->circular_buffer_options.smem_circular_buffer_stage,
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
}
for (TensorView* bcw_smem : bcw_smems_) {
bcw_smem->circularBuffer(
params_->circular_buffer_options.smem_circular_buffer_stage);
params_->circular_buffer_options.smem_circular_buffer_stage,
params_->circular_buffer_options.smem_circular_buffer_stage -
params_->circular_buffer_options
.smem_circular_buffer_prefetch_gap);
}
}

Expand Down
Loading

0 comments on commit 6363298

Please sign in to comment.