Remove MmaOpDetails::input_layout and getInputLayout (#3322)
There is no reason for us to check the Mma layout anymore when defining
an MmaOp, since that is all handled in the scheduler now. I also added a
test where a new batch dimension is broadcast before defining the
MmaOp.

Fixes #2273.
jacobhinkle authored Oct 31, 2024
1 parent 4cf9533 commit abdc3e1
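A minimal sketch of the pattern this change enables, condensed from the AmpereMatmulBroadcastBatch test added below. The shapes are chosen for illustration (TN-style operands assumed); the helpers (makeContigConcreteTensor, broadcast, fusedMultiplySum) are the same ones the diff uses:

Fusion fusion;
FusionGuard fg(&fusion);

// TN-style operands: A = [M, 1, K], B = [1, N, K] with M=504, N=136, K=248.
auto tv0 = makeContigConcreteTensor({504, 1, 248}, DataType::Half);
auto tv1 = makeContigConcreteTensor({1, 136, 248}, DataType::Half);
fusion.addInput(tv0);
fusion.addInput(tv1);

// Prepend a broadcast batch dimension: [1, M, 1, K] and [1, 1, N, K].
tv0 = broadcast(tv0, {true, false, false, false});
tv1 = broadcast(tv1, {true, false, false, false});

// Defining the MmaOp no longer consults getInputLayout, so the extra
// broadcast axis is accepted; the matmul scheduler infers the layout.
auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
fusion.addOutput(tv2);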
Showing 3 changed files with 63 additions and 60 deletions.
58 changes: 0 additions & 58 deletions csrc/ir/utils.cpp
@@ -1224,55 +1224,6 @@ TensorViewDetails getDetailsFor(const std::vector<IterDomain*>& dims) {
  return details;
}

MmaLayout getInputLayout(
    const TensorViewDetails& in_a,
    const TensorViewDetails& in_b,
    const MmaOp::AxesData& m_axes,
    const MmaOp::AxesData& n_axes,
    const MmaOp::AxesData& k_axes) {
  // TT layout (b - broadcast, r - reduction):
  // A = [M, K, b]
  // B = [b, K, N]
  // C = [M, r, N] (root domain)
  if ((m_axes.front() < in_a.bcasts.front()) &&
      (k_axes.front() < in_a.bcasts.front()) &&
      (in_b.bcasts.front() < k_axes.front()) &&
      (in_b.bcasts.front() < n_axes.front())) {
    return MmaLayout::TT;
  }
  // TN layout (b - broadcast, r - reduction):
  // A = [M, b, K]
  // B = [b, N, K]
  // C = [M, N, r] (root domain)
  if ((m_axes.front() < in_a.bcasts.front()) &&
      (in_a.bcasts.front() < k_axes.front()) &&
      (in_b.bcasts.front() < n_axes.front()) &&
      (in_b.bcasts.front() < k_axes.front())) {
    return MmaLayout::TN;
  }
  // NT layout (b - broadcast, r - reduction):
  // A = [K, M, b]
  // B = [K, b, N]
  // C = [r, M, N] (root domain)
  if ((k_axes.front() < in_a.bcasts.front()) &&
      (m_axes.front() < in_a.bcasts.front()) &&
      (k_axes.front() < in_b.bcasts.front()) &&
      (in_b.bcasts.front() < n_axes.front())) {
    return MmaLayout::NT;
  }
  // NN layout (b - broadcast, r - reduction):
  // A = [b, K, M]
  // B = [N, K, b]
  // C = [N, r, M] (root domain)
  if ((in_a.bcasts.front() < k_axes.front()) &&
      (k_axes.front() < m_axes.front()) && (n_axes.front() < k_axes.front()) &&
      (k_axes.front() < in_b.bcasts.front())) {
    return MmaLayout::NN;
  }

  NVF_THROW("Unsupported input layout");
}

MmaOpDetails getMmaOpDetails(
    TensorView* out,
    TensorView* in_a,
@@ -1405,15 +1356,6 @@ MmaOpDetails getMmaOpDetails(
      !details.k_axes.empty(),
      "MmaOp inputs must define at least a single K dimension");

  // TODO: for tensor contraction / split-k uses of MmaOp different input layout
  // rules may be needed
  details.input_layout = getInputLayout(
      in_a_details,
      in_b_details,
      details.m_axes,
      details.n_axes,
      details.k_axes);

  return details;
}

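For context, the removed getInputLayout classified the operand layout purely from the relative root-domain positions of the M, N, K, and broadcast axes. A hedged walk-through, reconstructed from the branches above, of why a prepended broadcast batch dimension could not be classified:

// TN operands A = [M, b, K] and B = [b, N, K] have root-domain positions
//   A: M at 0, broadcast at 1, K at 2;  B: broadcast at 0, N at 1, K at 2,
// which satisfy the TN branch: m < bcast_a < k, plus bcast_b < n and
// bcast_b < k. Prepending a broadcast batch axis gives A = [b, M, b, K] and
// B = [b, b, N, K], so bcasts.front() becomes 0 for both inputs. None of the
// four branches can then match (TT/TN need m < bcast_a, NT needs k < bcast_a,
// NN needs k < m), and NVF_THROW("Unsupported input layout") fired. With the
// check gone, layout determination is left entirely to the scheduler.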
2 changes: 0 additions & 2 deletions csrc/ir/utils.h
@@ -38,8 +38,6 @@ struct MmaOpDetails {
  // Concrete or broadcast axes that are present in all inputs
  // and output
  AxesData batch_axes;
  // A placeholder for mma input layout
  std::optional<MmaLayout> input_layout = std::nullopt;
};

// A helper structure with pieces of information about TensorView
63 changes: 63 additions & 0 deletions tests/cpp/test_matmul.cpp
@@ -140,6 +140,69 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) {
  NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

// Single batch dimension which is broadcast
TEST_P(MatmulTestWithLayout, AmpereMatmulBroadcastBatch) {
  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;

  Fusion fusion;
  FusionGuard fg(&fusion);

  auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout);

  auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half);
  auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A);
  tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B);
  // Broadcast inputs to 1, M, 1, K and 1, 1, N, K
  tv0 = broadcast(tv0, {true, false, false, false});
  tv1 = broadcast(tv1, {true, false, false, false});
  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});

  fusion.addOutput(tv2);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 128, 32);
  gemm_tile.warp_tile = GemmTile(64, 64, 32);
  gemm_tile.instruction_tile = GemmTile(16, 8, 16);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 4};
  mparams.mma_macro = MmaMacro::Ampere_16_8_16;
  mparams.tile_sizes = gemm_tile;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = true;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  auto inputs = matmulAtInput3DTuring(M, N, K, layout);

  FusionExecutor fe;
  NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
      8,
      0,
      fe.compileFusion(
          &fusion,
          {inputs.first, inputs.second},
          LaunchParams(),
          matmul_cparams));
  ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty());
  ASSERT_FALSE(
      PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel()));
  auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
  auto tref =
      atMatmul(
          inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout)
          .unsqueeze(0);
  NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}
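A usage note: the test is parameterized over operand layouts, so a gtest wildcard filter selects all instantiations. The binary name below is an assumption (nvFuser splits its tests across several executables); adjust it for the actual build:

# Hypothetical test binary path; the gtest filter flag itself is standard.
./bin/test_matmul --gtest_filter='*AmpereMatmulBroadcastBatch*'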

TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) {
  // Keep multiples of 8 to keep vectorizable.
  int M = 504, N = 136, K = 248;
