Fix schedule and lowering of Hopper MMAs (#2176)
This PR fixes the schedule and lowering of Hopper MMAs. Please see my
own review on this PR for details. After #2194, the schedule and
lowering of MMA finally make sense.

Note that MMA is still limited to a single tile in smem. Supporting
multiple tiles is the next step.
zasdfgbnm authored May 8, 2024
1 parent 5af7104 commit 7f7126d
Showing 7 changed files with 42 additions and 69 deletions.
42 changes: 26 additions & 16 deletions csrc/device_lower/pass/index.cpp
@@ -1575,6 +1575,7 @@ static MmaInputSmemSwizzle getSwizzleMode(TensorView* tv) {
 // Reference for smem strides:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#strides
 void IndexLowering::handle(const MmaOp* mma) {
+  constexpr int64_t core_matrix_outer_size = 8;
   Val* a = nullptr;
   Val* b = nullptr;
   if (mma->inA()->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
@@ -1583,13 +1584,19 @@ void IndexLowering::handle(const MmaOp* mma) {
     auto tv = mma->inA()->as<TensorView>();
     auto base_addr = IrBuilder::baseAddressExpr(tv);
     auto swizzle = getSwizzleMode(tv);
-    int64_t stride_bytes =
-        8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes
-    int64_t leading_bytes = /*8x8 items each core matrix*/ 64L *
-        /*number of core matrices*/ (getM(mma->macro()) / 8L) *
-        /*bytes per item*/ 2L;
-    if (swizzle != MmaInputSmemSwizzle::None) {
-      // TODO: why???!!!
+    int64_t leading_bytes = core_matrix_outer_size *
+        getBytesFromSwizzle(swizzle); // swizzle period in bytes
+    int64_t inner_size =
+        (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN)
+        ? getK(mma->macro())
+        : getM(mma->macro());
+    int64_t stride_bytes = core_matrix_outer_size *
+        /*number of core matrices, rounded up to handle padding */
+        roundUpToMultiple(inner_size * /*bytes per item*/ 2L,
+            getBytesFromSwizzle(swizzle));
+    if (swizzle == MmaInputSmemSwizzle::None &&
+        (mma->layout() == MmaLayout::NT || mma->layout() == MmaLayout::NN)) {
+      // tnspA and tnspB is ignored for NoSwizzle mode
       std::swap(leading_bytes, stride_bytes);
     }
     auto matrix_desc = constructMatrixDescriptor(
@@ -1612,16 +1619,19 @@ void IndexLowering::handle(const MmaOp* mma) {
     auto tv = mma->inB()->as<TensorView>();
     auto swizzle = getSwizzleMode(tv);
     auto base_addr = IrBuilder::baseAddressExpr(tv);
-    int64_t stride_bytes =
-        8L * getBytesFromSwizzle(swizzle); // swizzle period in bytes
-    int64_t leading_bytes = /*8x8 items each core matrix*/ 64L *
+    int64_t leading_bytes = core_matrix_outer_size *
+        getBytesFromSwizzle(swizzle); // swizzle period in bytes
+    int64_t inner_size =
+        (mma->layout() == MmaLayout::TN || mma->layout() == MmaLayout::NN)
+        ? getK(mma->macro())
+        : getN(mma->macro());
+    int64_t stride_bytes = core_matrix_outer_size *
         /*number of core matrices, rounded up to handle padding */
-        roundUpToMultiple(getN(mma->macro()) / 8L,
-            getBytesFromSwizzle(swizzle) / 16L) *
-        /*bytes per item*/ 2L;
-    if (swizzle != MmaInputSmemSwizzle::None &&
-        (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::TN)) {
-      // TODO: why???!!!
+        roundUpToMultiple(inner_size * /*bytes per item*/ 2L,
+            getBytesFromSwizzle(swizzle));
+    if (swizzle == MmaInputSmemSwizzle::None &&
+        (mma->layout() == MmaLayout::TT || mma->layout() == MmaLayout::NT)) {
+      // tnspA and tnspB is ignored for NoSwizzle mode
       std::swap(leading_bytes, stride_bytes);
     }
     auto matrix_desc = constructMatrixDescriptor(
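To make the byte-offset arithmetic above concrete, here is a small
standalone C++ sketch (assumptions: fp16 operands, a 128-byte swizzle
mode, and K = 16 for a K-major operand; roundUpToMultiple is
re-implemented locally, so this is not the actual nvFuser helper code):

#include <cstdint>
#include <iostream>

// Local stand-in for nvFuser's roundUpToMultiple: round x up to a multiple of m.
std::int64_t roundUpToMultiple(std::int64_t x, std::int64_t m) {
  return ((x + m - 1) / m) * m;
}

int main() {
  constexpr std::int64_t core_matrix_outer_size = 8; // core matrices are 8x8
  constexpr std::int64_t bytes_per_item = 2; // fp16
  constexpr std::int64_t swizzle_bytes = 128; // 128B swizzle period
  constexpr std::int64_t inner_size = 16; // K extent of a hypothetical K16 macro

  // Offset between swizzle periods: 8 core-matrix rows x one swizzle period.
  std::int64_t leading_bytes = core_matrix_outer_size * swizzle_bytes;
  // Offset between rows of core matrices; a row occupying only
  // inner_size * bytes_per_item = 32 bytes is padded out to a full 128B period.
  std::int64_t stride_bytes = core_matrix_outer_size *
      roundUpToMultiple(inner_size * bytes_per_item, swizzle_bytes);

  std::cout << "leading_bytes = " << leading_bytes // prints 1024
            << ", stride_bytes = " << stride_bytes // prints 1024
            << std::endl;
  return 0;
}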
12 changes: 6 additions & 6 deletions csrc/device_lower/pass/inline_ptx.cpp
@@ -223,18 +223,18 @@ class LowerToInlinePtx : public kir::ExprMutator {
         /*scaleB=*/IrBuilder::create<Val>(1, DataType::Int32)};
     auto layout = *mma->layout();
     if (a_on_smem) {
-      // tnspA: if not K-major, then needs transpose
+      // tnspA
       if (layout == MmaLayout::TT || layout == MmaLayout::TN) {
-        inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
-      } else {
         inputs.push_back(IrBuilder::create<Val>(0, DataType::Int32));
+      } else {
+        inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
       }
     }
-    // tnspB: if not K-major, then needs transpose
+    // tnspB
     if (layout == MmaLayout::TN || layout == MmaLayout::NN) {
-      inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
-    } else {
       inputs.push_back(IrBuilder::create<Val>(0, DataType::Int32));
+    } else {
+      inputs.push_back(IrBuilder::create<Val>(1, DataType::Int32));
     }
     registerInsertBefore(
         mma,
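The gist of the fix is easiest to see as a predicate: tnsp is 0 exactly
when the operand is already K-major in smem. A minimal sketch of that
rule (MmaLayout here is a hypothetical stand-in mirroring nvFuser's four
layouts, not the real lowering code):

enum class MmaLayout { TT, TN, NT, NN };

// tnspA: A is K-major for TT/TN, so no transpose (0); otherwise transpose (1).
int tnspA(MmaLayout layout) {
  return (layout == MmaLayout::TT || layout == MmaLayout::TN) ? 0 : 1;
}

// tnspB: B is K-major for TN/NN, so no transpose (0); otherwise transpose (1).
int tnspB(MmaLayout layout) {
  return (layout == MmaLayout::TN || layout == MmaLayout::NN) ? 0 : 1;
}

Note how these conditions line up with the inner_size selection in
index.cpp above: the same K-major test drives both the descriptor
strides and the transpose bits.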
6 changes: 1 addition & 5 deletions csrc/ir/interface_nodes.h
@@ -419,11 +419,7 @@ class NVF_API TensorView : public Val {
   //! have a matching thread swizzle with the mma operand/result.
   //! More detail on usage see [WarpMmaSwizzler] in scheduler/mma_utils.h .
   void applyMmaSwizzle(MmaOperand operand);
-  // TODO: what is transpose 2? Why do we need it?
-  void applyMmaSwizzle(
-      MmaInputSmemSwizzle swizzle,
-      bool transpose,
-      bool transpose2 = false);
+  void applyMmaSwizzle(MmaInputSmemSwizzle swizzle);
 
   //! Returns if this tensor view has swizzle operator on its tensor domain.
   //! This is the temporary flag for indicating that the new swizzle
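For callers the signature change is mechanical; a before/after pair
(taken from the test updates further down, where transpose_a and
transpose_b were derived from the layout):

// Before: operand-dependent transpose flags had to be passed explicitly.
tv1->applyMmaSwizzle(swizzle_b, transpose_b, transpose_a);
// After: the swizzle mode alone determines the schedule.
tv1->applyMmaSwizzle(swizzle_b);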
20 changes: 3 additions & 17 deletions csrc/scheduler/mma_utils.cpp
@@ -853,12 +853,7 @@ void WarpMmaSwizzler::scheduleOperandRead(TensorView* tv, MmaOperand operand) {
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-swizzling-modes
 void WarpMmaSwizzler::scheduleOperandRead(
     TensorView* tv,
-    MmaInputSmemSwizzle swizzle,
-    bool transpose,
-    bool transpose2) {
-  if (transpose) {
-    tv->reorder({{-2, -1}});
-  }
+    MmaInputSmemSwizzle swizzle) {
   if (swizzle == MmaInputSmemSwizzle::None) {
     // For no-swizzle case, the entire tile are divided into 8x8 core matrices,
     // and each core matrix resides in a contiguous 8*8*2 bytes region in shared
@@ -868,16 +863,9 @@ void WarpMmaSwizzler::scheduleOperandRead(
     // [Ko, K8, Mo, M8]
     tv->reorder({{-2, -3}});
     // [Ko, Mo, K8, M8]
-    if (transpose2) {
-      tv->reorder({{-2, -1}});
-    }
   } else {
     auto swizzle_size = getBytesFromSwizzle(swizzle) / 16;
     // For example, [K, M]
-    if (transpose2) {
-      tv->reorder({{-2, -1}});
-      // [M, K]
-    }
     tv->split(-2, 8);
     tv->split(-1, 8);
     // For example transpose2 == false
@@ -891,10 +879,8 @@ void WarpMmaSwizzler::scheduleOperandRead(
     tv->split(-4, 8 / swizzle_size);
     // [Ko, K2, K4, Moo, Mo2, M8]
     tv->swizzle(SwizzleType::XOR, -5, -2);
-    if (!transpose2) {
-      tv->reorder({{-3, -5}});
-      // [Ko, Moo, K2, K4, Mo2, M8]
-    }
+    tv->reorder({{-3, -5}});
+    // [Ko, Moo, K2, K4, Mo2, M8]
   }
   tv->setAllocationDomain(tv->getLeafDomain(), true);
 }
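The XOR swizzle applied above is what keeps the 8 rows of a core matrix
in distinct shared-memory banks. A standalone toy illustration of the
pattern (assuming a 128-byte swizzle, i.e. swizzle_size = 8 chunks of
16 bytes; this only prints the mapping and is not nvFuser's Swizzle
implementation):

#include <cstdio>

int main() {
  constexpr int swizzle_size = 8; // 128B swizzle: 8 chunks of 16B per period
  // For each row, print which 16B chunk each logical column is placed in.
  // XOR-ing the row index into the column index permutes the chunks so
  // that, down any fixed column, the 8 rows land in 8 different chunks.
  for (int row = 0; row < 8; ++row) {
    for (int col = 0; col < swizzle_size; ++col) {
      std::printf("%d ", col ^ row);
    }
    std::printf("\n");
  }
  return 0;
}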
7 changes: 1 addition & 6 deletions csrc/scheduler/mma_utils.h
@@ -173,12 +173,7 @@ class WarpMmaSwizzler {
   //! after smem read.
   //! The rightmost iterdomains must follow the m,n,k convention before calling.
   static void scheduleOperandRead(TensorView* tv, MmaOperand operand);
-  // TODO: what is transpose2? Why do we need it?
-  static void scheduleOperandRead(
-      TensorView* tv,
-      MmaInputSmemSwizzle swizzle,
-      bool transpose,
-      bool transpose2);
+  static void scheduleOperandRead(TensorView* tv, MmaInputSmemSwizzle swizzle);
 
   //! Note [schedule of ldmatrix]
   //! If you look at the doc of ldmatrix and mma for Turing and Ampere:
8 changes: 2 additions & 6 deletions csrc/tensor_view.cpp
@@ -1313,15 +1313,11 @@ void TensorView::applyMmaSwizzle(MmaOperand operand) {
   }
 }
 
-void TensorView::applyMmaSwizzle(
-    MmaInputSmemSwizzle swizzle,
-    bool transpose,
-    bool transpose2) {
+void TensorView::applyMmaSwizzle(MmaInputSmemSwizzle swizzle) {
   NVF_ERROR(
       getMemoryType() == MemoryType::Shared,
       "Shared memory swizzle is only supported for shared memory");
-  mma_utils::WarpMmaSwizzler::scheduleOperandRead(
-      this, swizzle, transpose, transpose2);
+  mma_utils::WarpMmaSwizzler::scheduleOperandRead(this, swizzle);
 }
 
 void TensorView::commitLeafToRFactor() {
16 changes: 3 additions & 13 deletions tests/cpp/test_mma.cpp
@@ -349,9 +349,6 @@ class HopperRS : public HopperBase,
     dtype = std::get<1>(GetParam());
     layout = std::get<2>(GetParam());
     swizzle_b = std::get<3>(GetParam());
-    if (layout != MmaLayout::TN) {
-      GTEST_SKIP() << "bugs to be fixed";
-    }
   }
 };
 
@@ -417,7 +414,7 @@ TEST_P(HopperRS, SingleTile) {
   tv0->merge(1);
   tv0->axis(1)->parallelize(ParallelType::TIDx);
 
-  tv1->applyMmaSwizzle(swizzle_b, layout == MmaLayout::TN);
+  tv1->applyMmaSwizzle(swizzle_b);
 
   naivelyParallelize(tv1);
 
@@ -490,9 +487,6 @@ class HopperSS : public HopperBase,
     layout = std::get<2>(GetParam());
     swizzle_a = std::get<3>(GetParam());
     swizzle_b = std::get<4>(GetParam());
-    if (layout != MmaLayout::TN) {
-      GTEST_SKIP() << "bugs to be fixed";
-    }
   }
 };
 
@@ -528,9 +522,6 @@ TEST_P(HopperSS, SingleTile) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  bool transpose_a = (layout == MmaLayout::NT || layout == MmaLayout::NN);
-  bool transpose_b = (layout == MmaLayout::TN || layout == MmaLayout::NN);
-
   auto shapes = matmulAtInputShape3DHopperSS(
       getM(macro), getN(macro), getK(macro), layout);
 
@@ -602,9 +593,8 @@ TEST_P(HopperSS, SingleTile) {
   moveInnerBroadcastLeft(tv1);
 
-  // Hopper tensor core assumes K major, so we are using !transpose_a here.
-  tv0->applyMmaSwizzle(swizzle_a, !transpose_a);
   tv1->setMemoryType(MemoryType::Shared);
-  tv1->applyMmaSwizzle(swizzle_b, transpose_b, transpose_a);
+  tv0->applyMmaSwizzle(swizzle_a);
+  tv1->applyMmaSwizzle(swizzle_b);
 
   naivelyParallelize(tv0);
   naivelyParallelize(tv1);
