Refactor StMatrix Logic (#3553)
This PR contains the actual code changes from #3552.

1. Fix `hardCodedIndexGenerationForStMatrixSwizzle` for stmatrix (16, 8). The lane column is always `lane_id / 16`, not `lane_id / stsm_n_tile`.
2. Fix `tmaSwizzleSharedMemory` by ensuring the box inner dimension is evenly divisible by the swizzle size.
3. Clean up `scheduleTMAStoreForMmaOutput` and `scheduleStMatrixForMmaOutput` by removing unnecessary scheduling and correcting `num_ids_to_skip`.
4. Replace `mma_macro_to_str_map` with a `macroToString` function.
rdspring1 authored Dec 10, 2024
1 parent 98352c4 commit 9575fd6
Showing 6 changed files with 31 additions and 66 deletions.
6 changes: 4 additions & 2 deletions csrc/device_lower/pass/index.cpp
```diff
@@ -1972,6 +1972,7 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
   constexpr int64_t warp_size = 32;
   constexpr int64_t swizzle_row_size = 8;
   constexpr int64_t stsm_column_size = 8;
+  constexpr int64_t max_stsm_n_tile = 16;
   constexpr int64_t megabank_size_bytes = 16;
 
   // Derived constants
@@ -1987,7 +1988,8 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
   // NvFuser Val for constants
   Val* warp_size_val = IrBuilder::create<Val>(warp_size, DataType::Index);
   Val* stsm_m_tile_val = IrBuilder::create<Val>(stsm_m_tile, DataType::Index);
-  Val* stsm_n_tile_val = IrBuilder::create<Val>(stsm_n_tile, DataType::Index);
+  Val* max_stsm_n_tile_val =
+      IrBuilder::create<Val>(max_stsm_n_tile, DataType::Index);
   Val* stsm_n_tile_stride_val =
       IrBuilder::create<Val>(stsm_n_tile_stride, DataType::Index);
   Val* swizzle_row_size_val =
@@ -2020,7 +2022,7 @@ Val* hardCodedIndexGenerationForStMatrixSwizzle(
   row = GpuLower::current()->commonScalarMap().hoistScalar(row, {loop});
 
   // Calculate Column
-  Val* lane_col = SimplifyingIrBuilder::divExpr(lane_id, stsm_n_tile_val);
+  Val* lane_col = SimplifyingIrBuilder::divExpr(lane_id, max_stsm_n_tile_val);
   Val* iter_col =
       SimplifyingIrBuilder::mulExpr(inner_index, stsm_n_tile_stride_val);
   Val* col = SimplifyingIrBuilder::addExpr(lane_col, iter_col);
```
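To make item 1 of the commit message concrete, here is a minimal standalone sketch (plain C++ for illustration, not nvFuser code; the sampled `lane_id` values are arbitrary) contrasting the old and new divisors for an stmatrix (16, 8) tile, where `stsm_n_tile = 8`. The old formula spreads the 32 lanes across four column blocks, while the fixed divisor of 16 keeps them in the two column blocks the layout actually has:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t stsm_n_tile = 8;      // N extent of the stmatrix tile
  constexpr int64_t max_stsm_n_tile = 16; // fixed divisor used by the fix
  for (int64_t lane_id : {0, 7, 8, 15, 16, 31}) {
    // Old (buggy): lane_id / stsm_n_tile -> ranges over 0..3 for a (16, 8) tile.
    // New (fixed): lane_id / max_stsm_n_tile -> 0..1, i.e. lane_id / 16.
    std::cout << "lane " << lane_id << ": old col " << (lane_id / stsm_n_tile)
              << ", new col " << (lane_id / max_stsm_n_tile) << '\n';
  }
  return 0;
}
```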
39 changes: 15 additions & 24 deletions csrc/scheduler/hopper_multi_matmul.cpp
```diff
@@ -305,11 +305,11 @@ MmaInputSmemSwizzle tmaSwizzleSharedMemory(TensorView* shared_mem_tv) {
   const int64_t B64_elements = 64 / dataTypeSize(dtype);
   const int64_t B32_elements = 32 / dataTypeSize(dtype);
 
-  if (inner_dim_size >= B128_elements) {
+  if (inner_dim_size % B128_elements == 0) {
     return MmaInputSmemSwizzle::B128;
-  } else if (inner_dim_size >= B64_elements) {
+  } else if (inner_dim_size % B64_elements == 0) {
     return MmaInputSmemSwizzle::B64;
-  } else if (inner_dim_size >= B32_elements) {
+  } else if (inner_dim_size % B32_elements == 0) {
     return MmaInputSmemSwizzle::B32;
   } else {
     NVF_THROW("Unsupported swizzle size for TMA shared memory mma inputs");
```
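To illustrate item 2, a standalone sketch of the fixed selection logic (the `pickSwizzle` helper is hypothetical, not nvFuser's API; it assumes a 2-byte dtype such as half, so the B128/B64/B32 boxes hold 64/32/16 elements): for `inner_dim_size = 48`, the old `>=` comparison would choose B64 even though 48 is not divisible by 32, whereas the divisibility test falls back to B32:

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Sketch of the fixed selection above; dtype_size is in bytes.
std::string pickSwizzle(int64_t inner_dim_size, int64_t dtype_size) {
  const int64_t b128_elements = 128 / dtype_size;
  const int64_t b64_elements = 64 / dtype_size;
  const int64_t b32_elements = 32 / dtype_size;
  if (inner_dim_size % b128_elements == 0) {
    return "B128";
  } else if (inner_dim_size % b64_elements == 0) {
    return "B64";
  } else if (inner_dim_size % b32_elements == 0) {
    return "B32";
  }
  return "unsupported";
}

int main() {
  std::cout << pickSwizzle(48, 2) << '\n'; // B32: 48 >= 32, but 48 % 32 != 0
  std::cout << pickSwizzle(64, 2) << '\n'; // B128: 64 % 64 == 0
  return 0;
}
```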
```diff
@@ -1081,10 +1081,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
 
       // Schedule shared memory cache; Output from StMatrix
       scheduleStMatrixForMmaOutput(
-          d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n);
+          d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
 
       // Schedule global memory output; Output from TMA Store
-      scheduleTMAStoreForMmaOutput(d, swizzle, tma_m, tma_n);
+      scheduleTMAStoreForMmaOutput(d, swizzle);
     }
   }
 }
@@ -1244,9 +1244,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
     TensorView* tv,
     MmaInputSmemSwizzle swizzle,
     int64_t tile_m,
-    int64_t tile_n,
-    int64_t tma_m,
-    int64_t tma_n) {
+    int64_t tile_n) {
   NVF_ERROR(
       ((tile_m == 16 && tile_n == 16) || (tile_m == 16 && tile_n == 8)),
       "We only support 16x16 and 16x8 stmatrix now");
@@ -1258,8 +1256,10 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
   auto s =
       mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());
 
-  // Create tma store allocation domain with swizzle
-  scheduleTMAStoreForMmaOutput(tv, swizzle, tma_m, tma_n);
+  if (swizzle != MmaInputSmemSwizzle::None) {
+    // Create tma store allocation domain with swizzle
+    scheduleTMAStoreForMmaOutput(tv, swizzle);
+  }
 
   tv->setLoopDomain(s.as<IterDomain*>());
 
@@ -1286,18 +1286,11 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
 
 void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
     TensorView* tv,
-    MmaInputSmemSwizzle swizzle,
-    int64_t m,
-    int64_t n) {
-  // [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)]
-  tv->split(-2, m);
-  tv->split(-1, n);
-  // [MO(1), MI(m), NO(1), NI(n)] -> [MO(1), NO(1), MI(m), NI(n)]
-  tv->reorder({{-2, -3}});
-
-  // [BDX, BDY, TDY, MO(1), NO(1), MI, NI]
-  // skip the first 5 iterDomains
-  int64_t num_ids_to_skip = 5;
+    MmaInputSmemSwizzle swizzle) {
+  // [BDX, BDY, TDY, MI, NI]
+  // skip all but last 2 iterDomains
+  int64_t num_ids_to_skip =
+      static_cast<int64_t>(tv->getLoopDomain().size() - 2);
 
   NVF_ERROR(num_ids_to_skip >= 0);
   if (swizzle == MmaInputSmemSwizzle::None) {
@@ -1308,8 +1301,6 @@ void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
     tv->split(-1, 8);
     // [Ko, K8, No, N8]
     tv->reorder({{-2, -3}});
-    // [Ko, No, K8, N8]
-    num_ids_to_skip += 2;
   } else {
     auto dtype = tv->getDataType().value();
 
```
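A note on the cleanup in item 3: `scheduleTMAStoreForMmaOutput` no longer pre-splits the output by `tma_m`/`tma_n` and no longer hard-codes the number of leading iterDomains to skip; it now derives `num_ids_to_skip` from the loop domain itself (all but the last two ids), so the tile extents do not have to be threaded through from `scheduleEpilogue`.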
8 changes: 2 additions & 6 deletions csrc/scheduler/hopper_multi_matmul.h
```diff
@@ -184,17 +184,13 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
       TensorView* tv,
       MmaInputSmemSwizzle swizzle,
       int64_t tile_m,
-      int64_t tile_n,
-      int64_t tma_m,
-      int64_t tma_n);
+      int64_t tile_n);
 
   //! Schedules the copy operation of output of a Mma op which resided in the
   //! shared memory to global memory.
   void scheduleTMAStoreForMmaOutput(
       TensorView* tv,
-      MmaInputSmemSwizzle swizzle,
-      int64_t m,
-      int64_t n);
+      MmaInputSmemSwizzle swizzle);
 
   // Map TensorView's iterDomain to its ValGroup.
   // Then, find the MatmulDimRole for the ValGroup.
```
2 changes: 1 addition & 1 deletion tests/cpp/test_matmul_scheduler.cpp
```diff
@@ -3134,7 +3134,7 @@ std::string hopperTestName(
   os << (a_k_inner ? "K" : "M");
   os << (b_k_inner ? "K" : "N");
   os << "_" << M << "_" << N << "_" << K;
-  os << "_MmaMacro_" << mma_macro_to_str_map.at(mma_macro);
+  os << "_MmaMacro_" << macroToString(mma_macro);
   if (use_smem_epilogue) {
     os << "_tma_store";
   }
```
8 changes: 8 additions & 0 deletions tests/cpp/utils.cpp
```diff
@@ -816,4 +816,12 @@ bool isVectorized(TensorView* tv) {
   return false;
 }
 
+std::string macroToString(const MmaMacro macro) {
+  std::stringstream ss;
+  ss << "m" << getM(macro);
+  ss << "_n" << getN(macro);
+  ss << "_k" << getK(macro);
+  return ss.str();
+}
+
 } // namespace nvfuser
```
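Since `getM`, `getN`, and `getK` decode the macro's shape, the new function reproduces the old map's strings exactly; for example, `macroToString(MmaMacro::Hopper_64_8_16)` yields `"m64_n8_k16"`, matching the first entry of the table deleted from `tests/cpp/utils.h` below, and newly added macros get a name without editing any table.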
34 changes: 1 addition & 33 deletions tests/cpp/utils.h
```diff
@@ -703,39 +703,7 @@ static auto kAllHopperMacros = testing::Values(
     MmaMacro::Hopper_64_248_16,
     MmaMacro::Hopper_64_256_16);
 
-static std::unordered_map<MmaMacro, std::string> mma_macro_to_str_map = {
-    {MmaMacro::Hopper_64_8_16, "m64_n8_k16"},
-    {MmaMacro::Hopper_64_16_16, "m64_n16_k16"},
-    {MmaMacro::Hopper_64_24_16, "m64_n24_k16"},
-    {MmaMacro::Hopper_64_32_16, "m64_n32_k16"},
-    {MmaMacro::Hopper_64_40_16, "m64_n40_k16"},
-    {MmaMacro::Hopper_64_48_16, "m64_n48_k16"},
-    {MmaMacro::Hopper_64_56_16, "m64_n56_k16"},
-    {MmaMacro::Hopper_64_64_16, "m64_n64_k16"},
-    {MmaMacro::Hopper_64_72_16, "m64_n72_k16"},
-    {MmaMacro::Hopper_64_80_16, "m64_n80_k16"},
-    {MmaMacro::Hopper_64_88_16, "m64_n88_k16"},
-    {MmaMacro::Hopper_64_96_16, "m64_n96_k16"},
-    {MmaMacro::Hopper_64_104_16, "m64_n104_k16"},
-    {MmaMacro::Hopper_64_112_16, "m64_n112_k16"},
-    {MmaMacro::Hopper_64_120_16, "m64_n120_k16"},
-    {MmaMacro::Hopper_64_128_16, "m64_n128_k16"},
-    {MmaMacro::Hopper_64_136_16, "m64_n136_k16"},
-    {MmaMacro::Hopper_64_144_16, "m64_n144_k16"},
-    {MmaMacro::Hopper_64_152_16, "m64_n152_k16"},
-    {MmaMacro::Hopper_64_160_16, "m64_n160_k16"},
-    {MmaMacro::Hopper_64_168_16, "m64_n168_k16"},
-    {MmaMacro::Hopper_64_176_16, "m64_n176_k16"},
-    {MmaMacro::Hopper_64_184_16, "m64_n184_k16"},
-    {MmaMacro::Hopper_64_192_16, "m64_n192_k16"},
-    {MmaMacro::Hopper_64_200_16, "m64_n200_k16"},
-    {MmaMacro::Hopper_64_208_16, "m64_n208_k16"},
-    {MmaMacro::Hopper_64_216_16, "m64_n216_k16"},
-    {MmaMacro::Hopper_64_224_16, "m64_n224_k16"},
-    {MmaMacro::Hopper_64_232_16, "m64_n232_k16"},
-    {MmaMacro::Hopper_64_240_16, "m64_n240_k16"},
-    {MmaMacro::Hopper_64_248_16, "m64_n248_k16"},
-    {MmaMacro::Hopper_64_256_16, "m64_n256_k16"}};
+std::string macroToString(const MmaMacro macro);
 
 // Utility to generate matmul input tensors based on given layout
 at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout);
```