diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index eb160d2e2de..db32870e713 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1202,55 +1202,6 @@ TensorViewDetails getDetailsFor(const std::vector& dims) { return details; } -MmaLayout getInputLayout( - const TensorViewDetails& in_a, - const TensorViewDetails& in_b, - const MmaOp::AxesData& m_axes, - const MmaOp::AxesData& n_axes, - const MmaOp::AxesData& k_axes) { - // TT layout (b - broadcast, r - reduction): - // A = [M, K, b] - // B = [b, K, N] - // C = [M, r, N] (root domain) - if ((m_axes.front() < in_a.bcasts.front()) && - (k_axes.front() < in_a.bcasts.front()) && - (in_b.bcasts.front() < k_axes.front()) && - (in_b.bcasts.front() < n_axes.front())) { - return MmaLayout::TT; - } - // TN layout (b - broadcast, r - reduction): - // A = [M, b, K] - // B = [b, N, K] - // C = [M, N, r] (root domain) - if ((m_axes.front() < in_a.bcasts.front()) && - (in_a.bcasts.front() < k_axes.front()) && - (in_b.bcasts.front() < n_axes.front()) && - (in_b.bcasts.front() < k_axes.front())) { - return MmaLayout::TN; - } - // NT layout (b - broadcast, r - reduction): - // A = [K, M, b] - // B = [K, b, N] - // C = [r, M, N] (root domain) - if ((k_axes.front() < in_a.bcasts.front()) && - (m_axes.front() < in_a.bcasts.front()) && - (k_axes.front() < in_b.bcasts.front()) && - (in_b.bcasts.front() < n_axes.front())) { - return MmaLayout::NT; - } - // NN layout (b - broadcast, r - reduction): - // A = [b, K, M] - // B = [N, K, b] - // C = [N, r, M] (root domain) - if ((in_a.bcasts.front() < k_axes.front()) && - (k_axes.front() < m_axes.front()) && (n_axes.front() < k_axes.front()) && - (k_axes.front() < in_b.bcasts.front())) { - return MmaLayout::NN; - } - - NVF_THROW("Unsupported input layout"); -} - MmaOpDetails getMmaOpDetails( TensorView* out, TensorView* in_a, @@ -1383,15 +1334,6 @@ MmaOpDetails getMmaOpDetails( !details.k_axes.empty(), "MmaOp inputs must define at least a single K dimension"); - // TODO: for tensor 
contraction / split-k uses of MmaOp different input layout - // rules may be needed - details.input_layout = getInputLayout( - in_a_details, - in_b_details, - details.m_axes, - details.n_axes, - details.k_axes); - return details; } diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 6ce59fb95c4..343afd88874 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -38,8 +38,6 @@ struct MmaOpDetails { // Concrete or broadcast axes that are present in all inputs // and output AxesData batch_axes; - // A placeholder for mma input layout - std::optional input_layout = std::nullopt; }; // A helper structure with pieces of information about TensorView diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 4b3b404a889..9b1333f6339 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -140,6 +140,69 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) { NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } +// Matmul where both operands carry a single batch dimension that is a broadcast +TEST_P(MatmulTestWithLayout, AmpereMatmulBroadcastBatch) { + // Keep M, N and K multiples of 8 so the problem stays vectorizable. 
+ int M = 504, N = 136, K = 248; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); + + auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); + auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + // Prepend a broadcast dimension, giving [1, M, 1, K] and [1, 1, N, K] + tv0 = broadcast(tv0, {true, false, false, false}); + tv1 = broadcast(tv1, {true, false, false, false}); + auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 4}; + mparams.mma_macro = MmaMacro::Ampere_16_8_16; + mparams.tile_sizes = gemm_tile; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = true; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + auto inputs = matmulAtInput3DTuring(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = + atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout) + .unsqueeze(0); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + 
TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248;