Cleanup applyMmaSwizzleForTMALoad (#3077)
In the past, we had a `permute_outer_dim` parameter for
`applyMmaSwizzleForTMALoad` because the strides in the MmaOp's matrix
descriptor were hard-coded. With hard-coded strides there was no
flexibility: we had to use multiple TMA instructions to load the smem
buffer in one very specific format. After #3002, the strides are
computed from the schedule, so we now have the flexibility to choose a
memory layout that needs fewer TMA instructions to load.
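
As a rough illustration only (not part of this diff's added code), a call site like the one in the test removed below changes from the two-argument to the single-argument form; `tv1` and `swizzle_b` are the names used in that test:

// Before: the caller picked the smem layout via an explicit flag.
tv1->applyMmaSwizzleForTMALoad(swizzle_b, /*permute_outer_dim=*/false);

// After: the flag is gone; the layout and the matrix-descriptor strides
// are derived from the schedule instead.
tv1->applyMmaSwizzleForTMALoad(swizzle_b);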
zasdfgbnm authored Oct 7, 2024
1 parent 888f720 commit c7c5de8
Showing 5 changed files with 6 additions and 111 deletions.
4 changes: 1 addition & 3 deletions csrc/ir/interface_nodes.h
@@ -427,9 +427,7 @@ class NVF_API TensorView : public Val {
//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are inputs of a MmaOp or are loaded
//! using TMA.
- void applyMmaSwizzleForTMALoad(
-     MmaInputSmemSwizzle swizzle,
-     bool permute_outer_dim = true);
+ void applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle);

//! Returns if this tensor view has swizzle operator on its tensor domain.
//! This is the temporary flag for indicating that the new swizzle
18 changes: 2 additions & 16 deletions csrc/scheduler/mma_utils.cpp
@@ -929,14 +929,9 @@ void MmaSwizzler::parallelizeAsBulkSkippingFirstIDs(
}
}

- // Please note that we currently do not fully support
- // not splitting the outer dimension. This only works when
- // the inner-dimension is not split, that is the inner dim
- // is less or equal to the swizzle size (in bytes).
void MmaSwizzler::scheduleTMALoadForMma(
    TensorView* tv,
-     MmaInputSmemSwizzle swizzle,
-     bool permute_outer_dim) {
+     MmaInputSmemSwizzle swizzle) {
// In the comments below I have kept K as the outer dimension. That is
// just to have a concrete running example - it can be inner or outer.

@@ -968,16 +963,7 @@ void MmaSwizzler::scheduleTMALoadForMma(
// [NO, K, NI] ->
// [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)]
tv->swizzleTMABox(swizzle);

- // If the outer dim is split, then we pull out KO to be outside NO
- // and KO and NO are both not marked bulk parallel, else NO is outer
- // and only NO is not marked bulk parallel.
- if (permute_outer_dim) {
-   // [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] ->
-   // [KO(2), NO(2), KIO(2), KII(4), NIO(2), NII(8)]
-   tv->reorder({{-6, -5}});
- }
- num_ids_to_skip += permute_outer_dim ? 2 : 1;
+ num_ids_to_skip += 1;
}

parallelizeAsBulkSkippingFirstIDs(tv, num_ids_to_skip);
3 changes: 1 addition & 2 deletions csrc/scheduler/mma_utils.h
@@ -233,8 +233,7 @@ class MmaSwizzler {
//! outermost.
static void scheduleTMALoadForMma(
TensorView* tv,
-     MmaInputSmemSwizzle swizzle,
-     bool permute_outer_dim = true);
+     MmaInputSmemSwizzle swizzle);

//! Parallelize all dims as bulk except the first dims mentioned in the second
//! param.
7 changes: 2 additions & 5 deletions csrc/tensor_view.cpp
@@ -1388,18 +1388,15 @@ void TensorView::swizzleTMABox(MmaInputSmemSwizzle swizzle) {
this->swizzle(SwizzleType::XOR, -4, -2);
}

- void TensorView::applyMmaSwizzleForTMALoad(
-     MmaInputSmemSwizzle swizzle,
-     bool permute_outer_dim) {
+ void TensorView::applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle) {
NVF_ERROR(
getMemoryType() == MemoryType::Shared,
"Shared memory swizzle is only supported for shared memory");
NVF_ERROR(
definition()->as<LoadStoreOp>()->opType() ==
LoadStoreOpType::CpAsyncBulkTensorTile,
"Operation requires a TMA operation");
- mma_utils::MmaSwizzler::scheduleTMALoadForMma(
-     this, swizzle, permute_outer_dim);
+ mma_utils::MmaSwizzler::scheduleTMALoadForMma(this, swizzle);
}

void TensorView::commitLeafToLogical() {
85 changes: 0 additions & 85 deletions tests/cpp/test_mma.cpp
@@ -878,91 +878,6 @@ TEST_P(HopperRS, SingleTileWithTMALoadStore) {
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_P(HopperRS, SingleTileWithTMALoadOuterDimNotSplit) {
if (layout == MmaLayout::TT) {
GTEST_SKIP() << "Skipping test as we only handle TN layout in this test";
}

Fusion fusion;
FusionGuard fg(&fusion);

auto shapes = matmulAtInputShape3DHopperRS(
getM(macro), getN(macro), getK(macro), layout);

auto tv0 = makeContigConcreteTensor(shapes.first, dtype);
auto tv1 = makeContigConcreteTensor(shapes.second, dtype);
fusion.addInput(tv0);
fusion.addInput(tv1);

// Just doing a gmem->register copy
tv0 = set(tv0);
// Just doing a gmem->smem copy
tv1 = set(tv1);
tv1->setMemoryType(MemoryType::Shared);
tv1->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

auto tv2 = fusedMultiplySum(tv0, tv1, {layout == MmaLayout::TT ? 1 : 2});

fusion.addOutput(tv2);

auto mma_ops = ir_utils::getOpsOfType<MmaOp>(&fusion);
NVF_CHECK(
1 == mma_ops.size(),
"Invalid number of MmaOp instances in fusion definition, expected 1, got ",
mma_ops.size());
mma_ops.front()->setMacro(macro);

auto tv2c = tv2->cacheBefore();

matmul_utils::moveInnerBroadcastLeft(tv0);
tv0->applyMmaSwizzle(MmaOperand::A);

tv0->merge(1);
tv0->merge(1);
tv0->axis(1)->parallelize(ParallelType::TIDx);

// In this case we don't split the outer dimension, thus having
// fewer TMA loads.
tv1->applyMmaSwizzleForTMALoad(swizzle_b, /* don't split outer dim*/ false);

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2c->getLoopDomain());
tv2c->setAllocationDomain(s.as<IterDomain*>(), true);
// Note: according to internal doc "Representing ldmatrix", we need both a
// read domain and a write domain to correctly represent MmaOp. Without this
// new mechanism, there is no correct loop domain, and the only choices are
// either we want to represent the smem read well, or represent the register
// write well. We choose to represent the smem read well here. Likely, this
// means we will not be able to have multiple tiles in register, but we can
// workaround this by always inlining the MmaOp most. We should fix this
// after we implemented the new read/write domain mechanism.
tv2c->axis(-1)->parallelize(ParallelType::Mma);
tv2c->axis(-2)->parallelize(ParallelType::Mma);
tv2c->axis(-3)->parallelize(ParallelType::Mma);
}
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
tv2->setLoopDomain(s.as<IterDomain*>());
}

auto inputs = matmulAtInput3DHopperRS(
getM(macro), getN(macro), getK(macro), layout, data_type_to_aten(dtype));

FusionExecutor fe;
fe.compileFusion(
&fusion, {inputs.first, inputs.second}, LaunchParams(), matmul_cparams);

auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
auto tref = atMatmul(
inputs.first.squeeze().to(at::kFloat),
inputs.second.squeeze().to(at::kFloat),
layout);
EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5));
}

TEST_P(HopperRS, MultipleTile) {
Fusion fusion;
FusionGuard fg(&fusion);
