rdspring1 committed Dec 12, 2024
1 parent 3b0cf99 commit e18741c
Showing 3 changed files with 20 additions and 1 deletion.
4 changes: 4 additions & 0 deletions csrc/device_lower/utils.cpp
@@ -921,6 +921,10 @@ std::array<UnitDim, 2> getMmaLayout(const MmaOp* expr) {

auto out_tv = ir_utils::getTv(expr->out());
IterDomain* reduction_id = nullptr;
// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. In the split-k case, use the root domain for the
// mma layout because the k dimension is divided into two iterDomains in the
// logical domain.
for (auto id : out_tv->getMaybeRootDomain()) {
if (id->isReduction()) {
reduction_id = id;
14 changes: 13 additions & 1 deletion csrc/tensor_view.cpp
@@ -804,7 +804,12 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
"Error rfactoring ",
this,
" its definition is either a nullptr or not a reduction.");

// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");
NVF_CHECK(
!definition()->isA<GroupedReductionOp>(),
"For GroupedReductionOp, use TensorView::rFactor(const std::vector<int64_t>& axes, const std::vector<TensorView*>& tvs)");
@@ -933,6 +938,13 @@ std::vector<TensorView*> TensorView::rFactor(
this,
" its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op.");

// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain.
NVF_CHECK(
definition()->isA<MmaOp>() || !domain()->hasRoot(),
"Cannot call rfactor on the same view twice.");

NVF_CHECK(
definition()->outputs().size() == tvs.size(),
"Rfactor of a multi-output reduction not used correctly");
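The guard added to both TensorView::rFactor overloads above passes when the definition is an MmaOp or when the domain has no root. A minimal standalone sketch of that predicate follows. ViewState and passesRFactorGuard are hypothetical names used only for illustration and are not part of nvFuser; the two booleans stand in for definition()->isA<MmaOp>() and domain()->hasRoot().

```cpp
// Hypothetical stand-in for the two facts the guard inspects.
// These are not nvFuser types.
struct ViewState {
  bool defined_by_mma;  // models definition()->isA<MmaOp>()
  bool has_root;        // models domain()->hasRoot()
};

// Models the condition passed to NVF_CHECK in the diff:
//   definition()->isA<MmaOp>() || !domain()->hasRoot()
// Returns true when the check would pass.
bool passesRFactorGuard(const ViewState& v) {
  return v.defined_by_mma || !v.has_root;
}
```

In this model, a non-MmaOp view that already has a root domain is still rejected, while an MmaOp-defined view passes whether or not it has a root domain. One consequence, as far as this predicate alone goes, is that it no longer distinguishes a first rFactor call from a repeated one on an MmaOp result.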
3 changes: 3 additions & 0 deletions csrc/transform_rfactor.cpp
@@ -340,6 +340,9 @@ std::pair<TensorDomain*, TensorDomain*> TransformRFactor::runReplay(
[](IterDomain* id) { return id->maybePartial(); }),
"rFactor of partial domains not allowed, but at least one found.");

// For hopper matmuls, the mma_result logical domain is reordered as [M, N, K]
// using commitLeafToLogical. Thus, the original logical domain is moved to
// the root domain. In this case, map from producer to consumer's root domain.
auto original_td_root = original_td->maybeRoot();

// Generate a new TensorDomain and set up map from one root to this one.