From b53c103b22cd9febeecfe6ea0377649365382763 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 11 Dec 2024 09:46:13 -0800 Subject: [PATCH 1/7] Remove unnecessary NVF_API (#3562) --- csrc/kernel_ir.h | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index f5f062cdbb7..60421db1995 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -77,7 +77,7 @@ class Predicate final : public Val { std::string toString(int indent_size = 0) const override; - NVF_API std::string toInlineString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; PredicateType predicate_type() const { return ptype_; @@ -148,7 +148,7 @@ class Predicate final : public Val { Val* value_ = nullptr; }; -class NVF_API TensorIndex final : public Val { +class TensorIndex final : public Val { public: TensorIndex( IrBuilderPasskey, @@ -252,7 +252,7 @@ class Asm final : public Expr { //! is required as an intermediate within a kernel. The extent is the expression //! of the size of the buffer that is generated from the TensorView that //! describes the output of an operation. -class NVF_API Allocate final : public Expr { +class Allocate final : public Expr { public: using Expr::Expr; @@ -385,7 +385,7 @@ class NVF_API Allocate final : public Expr { // // TODO(kir): change name to SyncThreads as we could have other barriers. // -class NVF_API BlockSync final : public Expr { +class BlockSync final : public Expr { public: using Expr::Expr; @@ -408,7 +408,7 @@ class NVF_API BlockSync final : public Expr { // Synchronize all blocks in device, implies cooperative group launch is // required. -class NVF_API GridSync final : public Expr { +class GridSync final : public Expr { public: using Expr::Expr; @@ -436,7 +436,7 @@ class NVF_API GridSync final : public Expr { }; // PTX: fence.proxy.async -class NVF_API FenceAsyncProxy final : public Expr { +class FenceAsyncProxy final : public Expr { public: using Expr::Expr; @@ -453,7 +453,7 @@ class NVF_API FenceAsyncProxy final : public Expr { }; // PTX: wgmma.fence.sync.aligned -class NVF_API WgMmaFence final : public Expr { +class WgMmaFence final : public Expr { public: using Expr::Expr; @@ -469,7 +469,7 @@ class NVF_API WgMmaFence final : public Expr { std::string toInlineString(int indent_size = 0) const override; }; -class NVF_API MBarrierInit final : public Expr { +class MBarrierInit final : public Expr { public: using Expr::Expr; explicit MBarrierInit( @@ -495,7 +495,7 @@ class NVF_API MBarrierInit final : public Expr { } }; -class NVF_API MBarrierInvalidate final : public Expr { +class MBarrierInvalidate final : public Expr { public: using Expr::Expr; explicit MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier); @@ -514,7 +514,7 @@ class NVF_API MBarrierInvalidate final : public Expr { } }; -class NVF_API MBarrierArrive final : public Expr { +class MBarrierArrive final : public Expr { public: using Expr::Expr; explicit MBarrierArrive(IrBuilderPasskey passkey, Val* state, Val* mbarrier); @@ -544,7 +544,7 @@ class NVF_API MBarrierArrive final : public Expr { // This is usually used to specify the number of bytes that will be // transferred for cp.async and cp.async.bulk, so that future mbarrier.wait // can wait for the completion of the transfer. -class NVF_API MBarrierArriveExpectTx final : public Expr { +class MBarrierArriveExpectTx final : public Expr { public: using Expr::Expr; explicit MBarrierArriveExpectTx( @@ -578,7 +578,7 @@ class NVF_API MBarrierArriveExpectTx final : public Expr { } }; -class NVF_API MBarrierWait final : public Expr { +class MBarrierWait final : public Expr { public: using Expr::Expr; explicit MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state); @@ -601,7 +601,7 @@ class NVF_API MBarrierWait final : public Expr { } }; -class NVF_API MBarrierWaitParity final : public Expr { +class MBarrierWaitParity final : public Expr { public: using Expr::Expr; explicit MBarrierWaitParity( @@ -796,7 +796,7 @@ class UpdateMagicZero final : public Expr { //! //! TODO(kir): this is not a real expression //! -class NVF_API IfThenElse final : public Expr { +class IfThenElse final : public Expr { public: using Expr::Expr; @@ -915,7 +915,7 @@ class GridReduction final : public ReductionOp { } }; -class NVF_API GroupedGridReduction final : public GroupedReductionOp { +class GroupedGridReduction final : public GroupedReductionOp { public: using GroupedReductionOp::GroupedReductionOp; @@ -1006,7 +1006,7 @@ class NVF_API GroupedGridReduction final : public GroupedReductionOp { //! //! This node provides KernelExecutor the information it needs to allocate the //! broadcast and sync buffers. -class NVF_API GridBroadcast final : public Expr { +class GridBroadcast final : public Expr { public: using Expr::Expr; @@ -1117,7 +1117,7 @@ class GridWelford final : public Expr { } }; -class NVF_API GroupedGridWelford final : public GroupedWelfordOp { +class GroupedGridWelford final : public GroupedWelfordOp { public: using GroupedWelfordOp::GroupedWelfordOp; @@ -1211,7 +1211,7 @@ class NVF_API GroupedGridWelford final : public GroupedWelfordOp { //! Represents a WelfordOp with the division by count is hoisted out //! of an innermost loop -class NVF_API VectorizedWelfordOp final : public WelfordOp { +class VectorizedWelfordOp final : public WelfordOp { public: using WelfordOp::WelfordOp; From 1b299286d5c64df3427a031222fa96c8244e3a6c Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 11 Dec 2024 11:43:22 -0800 Subject: [PATCH 2/7] Support 2D inner reduction scheduler with autotuning (#3456) # Summary: This PR creates an SOL autotuning script for the 2d inner reduction scheduler. It trains a random forest to predict the best performing configuration for the reduction scheduler. # Inner Reduction Fusions: 1. Sum --- `y = sum(x, dim=-1)` 2. Add Sum --- `z = sum(x1 + x2 + x3 + x4, dim=-1)` 3. Tanh Sum --- `y = sum(tanh(x), dim=-1)` 4. Exp Sum --- `z = sum(exp(x), dim=-1)` --- .../autotune_inner_reduction.py | 405 ++++++++++++++++++ doc/dev/python_scheduling/autotune_utils.py | 13 + 2 files changed, 418 insertions(+) create mode 100644 doc/dev/python_scheduling/autotune_inner_reduction.py diff --git a/doc/dev/python_scheduling/autotune_inner_reduction.py b/doc/dev/python_scheduling/autotune_inner_reduction.py new file mode 100644 index 00000000000..c43a20f5767 --- /dev/null +++ b/doc/dev/python_scheduling/autotune_inner_reduction.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +import itertools +from nvfuser import FusionDefinition, SchedulerType, DataType, ParallelType +from enum import Enum +from dataclasses import dataclass + +from autotune_utils import ( + ScriptConfiguration, + collect_data, + separate_data, + test_model_rmse, + test_model, + ceil_div, + floor_div, +) + + +# ================================ Description ================================ + +# This script defines four inner reduction fusions: +# +# 1. Inner Sum +# y = sum(x, dim=-1) +# +# 2. Add Sum +# z = sum(x1 + x2 + x3 + x4, dim=-1) +# +# 3. Tanh Sum +# y = sum(tanh(x), dim=-1) +# +# 4. Exp Sum +# z = sum(exp(x), dim=-1) +# +# Script Sequence: +# +# 1. Define a nvfuser fusion and its pytorch eager mode reference. +# +# 2. Profile the CUDA kernel performance by iterating over a set of input +# arguments and scheduler configurations. +# +# 3. Train a regression model to predict the desired performance metric given +# some input arguments and a scheduler configuration. +# +# 4. Measure the performance of the regression model. +# - Calculate RMSE of predicted and actual performance on test set. +# - Find the configuration with the best performance using regression model. +# Then, compare against the heuristic configuration selected by nvfuser. +# - For a specific batch size, gather performance across a range of hidden +# sizes. Calculate performance for best predicted and nvfuser +# configurations. Plot a chart comparing performance using matplotlib. +# +# The selected performance metric is effective_bandwidth_gbs. The empirical +# scheduler selects the configuration that has the highest predicted +# effective_bandwidth_gbs. + +# ============================================================================= + + +class AutotuneInnerReduction: + class FUSION(Enum): + INNER_SUM = 1 + ADD_SUM = 2 + TANH_SUM = 3 + EXP_SUM = 4 + + @dataclass(unsafe_hash=True) + class InnerReductionConfiguration: + # The vectorization factor for inner reduction domain. + vectorize_factor: int = 1 + # The unroll factor for the outer iteration domain. + unroll_factor: int = 1 + # The grid size for the outer iteration domain. + # If grdim > 1, then godim corresponds with y axis of the grid. + # Otherwise, it is the x axis of the grid. + godim: int = -1 + # The grid size for the inner reduction domain. It corresponds + # with x axis of the grid when it is >1. + grdim: int = -1 + # The x axis of CTA. It corresponds with inner reduction domain. + bdimx: int = -1 + # The y axis of CTA. It corresponds with outer reduction domain. + # If it is non-zero, then there are multiple reduction per CTA. + bdimy: int = -1 + + def __init__(self, selected_fusion): + self.selected_fusion = selected_fusion + + # gpu device properties are defined globally + assert torch.cuda.is_available() + self.gpu_properties = torch.cuda.get_device_properties(device=0) + + def __repr__(self): + return f"inner_reduction_{self.selected_fusion.name}" + + def convert_to_inner_reduction_params(self, scheduler_config, reduction_params): + warp_size = 32 + max_number_of_threads_cta = 1024 + grid_x_limit = 2147483647 + grid_y_limit = 65535 + + reduction_params.schedule_3D = False + reduction_params.fastest_dim = True + reduction_params.cross_block_inner_reduction = True + reduction_params.block_dim_inner_reduction = ParallelType.block_x + reduction_params.cross_grid_inner_reduction = scheduler_config.grdim > 1 + reduction_params.multiple_reds_per_blk = scheduler_config.bdimy > 1 + reduction_params.pad_inner_reduction_to_warp = ( + scheduler_config.bdimx > warp_size + ) and ( + (scheduler_config.bdimx * scheduler_config.bdimy) + < max_number_of_threads_cta + ) + reduction_params.unroll_factor_inner_reduction = ( + scheduler_config.vectorize_factor + ) + reduction_params.vectorize_inner_reduction = ( + scheduler_config.vectorize_factor > 1 + ) + + if scheduler_config.bdimy > 1: + reduction_params.block_dim_iter_dom = ParallelType.block_y + + reduction_params.unroll_factor_iter_dom = scheduler_config.unroll_factor + + gdimx = -1 + gdimy = -1 + + if scheduler_config.grdim > 1: + reduction_params.grid_dim_inner_reduction = ParallelType.grid_x + reduction_params.grid_dim_iter_dom = ParallelType.grid_y + + reduction_params.split_grid_dim_iter_dom_inner = True + gdimx = min(scheduler_config.grdim, grid_x_limit) + gdimy = min(scheduler_config.godim, grid_y_limit) + if scheduler_config.godim > grid_y_limit: + reduction_params.split_grid_dim_iter_dom_outer = True + else: + reduction_params.grid_dim_iter_dom = ParallelType.grid_x + gdimx = min(scheduler_config.godim, grid_x_limit) + if scheduler_config.godim > grid_x_limit: + reduction_params.split_grid_dim_inner_reduction = True + + reduction_params.lparams.gdimx = gdimx + reduction_params.lparams.gdimy = gdimy + + # Reset CTA dimensions to avoid failing LaunchParams::assertValid + reduction_params.lparams.bdimx = -1 + reduction_params.lparams.bdimy = -1 + reduction_params.lparams.bdimz = -1 + + reduction_params.lparams.bdimx = scheduler_config.bdimx + reduction_params.lparams.bdimy = scheduler_config.bdimy + + # For reduction scheduler, we test the cartesian product of vectorization and + # unroll factors. + def generate_scheduler_configurations(self, input_shape): + threads_per_cta_options = [128, 256, 512, 1024] + vectorization_factor_options = [1, 2, 4, 8] + unroll_factor_options = list(range(1, 11)) + warp_size = 32 + + num_iterations, num_reductions = input_shape + + for threads_per_cta, vectorize_factor, unroll_factor in itertools.product( + threads_per_cta_options, vectorization_factor_options, unroll_factor_options + ): + scheduler_config = self.InnerReductionConfiguration( + vectorize_factor=vectorize_factor, unroll_factor=unroll_factor + ) + scheduler_config.bdimx = min( + threads_per_cta, + max( + warp_size, + ceil_div(num_reductions, scheduler_config.vectorize_factor), + ), + ) + scheduler_config.bdimy = min( + threads_per_cta, + max(1, floor_div(threads_per_cta, scheduler_config.bdimx)), + ) + scheduler_config.godim = ceil_div( + num_iterations, scheduler_config.bdimy * scheduler_config.unroll_factor + ) + + # number of reduction elements not handled by a CTA + remaining_reduction = ceil_div( + num_reductions, + (scheduler_config.bdimx * scheduler_config.vectorize_factor), + ) + + if unroll_factor == 1 and remaining_reduction > 1: + # all remaining reduction goes to grdim + scheduler_config.grdim = remaining_reduction + yield scheduler_config + + # grid stride across reduction iterDomain is 1 + scheduler_config.grdim = 1 + yield scheduler_config + + def create_inputs(self, shape, tensor_datatype): + def inner_fn(num_inputs): + return [ + torch.randn(*shape, dtype=tensor_datatype, device="cuda") + for _ in range(num_inputs) + ] + + if self.selected_fusion == self.FUSION.ADD_SUM: + return inner_fn(num_inputs=4) + elif self.selected_fusion in [ + self.FUSION.INNER_SUM, + self.FUSION.TANH_SUM, + self.FUSION.EXP_SUM, + ]: + return inner_fn(num_inputs=1) + else: + assert False + + # A decorator to create a reduction fusion given some input arguments. + def create_fusion_func(self, inputs): + def sum_fusion(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.ops.cast(T0, dtype=DataType.Float) + T2 = fd.ops.sum(T1, dims=[1], keepdim=False, dtype=DataType.Null) + T3 = fd.ops.cast(T2, dtype=DataType.BFloat16) + fd.add_output(T3) + + def add_sum(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T2 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T3 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T4 = fd.ops.cast(T0, dtype=DataType.Float) + T5 = fd.ops.cast(T1, dtype=DataType.Float) + T6 = fd.ops.add(T4, T5) + T7 = fd.ops.cast(T2, dtype=DataType.Float) + T8 = fd.ops.add(T6, T7) + T9 = fd.ops.cast(T3, dtype=DataType.Float) + T10 = fd.ops.add(T8, T9) + T11 = fd.ops.sum(T10, dims=[1], keepdim=False, dtype=DataType.Null) + T12 = fd.ops.cast(T11, dtype=DataType.BFloat16) + fd.add_output(T12) + + def tanh_sum(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.ops.cast(T0, dtype=DataType.Float) + T2 = fd.ops.tanh(T1) + T3 = fd.ops.sum(T2, dims=[1], keepdim=False, dtype=DataType.Null) + T4 = fd.ops.cast(T3, dtype=DataType.BFloat16) + fd.add_output(T4) + + def exp_sum(fd: FusionDefinition) -> None: + T0 = fd.define_tensor( + shape=[-1, -1], + contiguity=[True, True], + dtype=DataType.BFloat16, + is_cpu=False, + stride_order=[1, 0], + ) + T1 = fd.ops.cast(T0, dtype=DataType.Float) + T2 = fd.ops.exp(T1) + T3 = fd.ops.sum(T2, dims=[1], keepdim=False, dtype=DataType.Null) + T4 = fd.ops.cast(T3, dtype=DataType.BFloat16) + fd.add_output(T4) + + if self.selected_fusion == self.FUSION.INNER_SUM: + return sum_fusion + elif self.selected_fusion == self.FUSION.ADD_SUM: + return add_sum + elif self.selected_fusion == self.FUSION.TANH_SUM: + return tanh_sum + elif self.selected_fusion == self.FUSION.EXP_SUM: + return exp_sum + else: + assert False + + # The pytorch eager mode reference used to validating nvfuser kernel. + def eager_reference(self, inputs): + def sum_fusion(inputs): + return torch.sum(inputs[0], dim=-1) + + def add_sum(inputs): + return torch.sum(inputs[0] + inputs[1] + inputs[2] + inputs[3], dim=-1) + + def tanh_sum(inputs): + return torch.sum(torch.tanh(inputs[0]), dim=-1) + + def exp_sum(inputs): + return torch.sum(torch.exp(inputs[0]), dim=-1) + + if self.selected_fusion == self.FUSION.INNER_SUM: + return sum_fusion(inputs) + elif self.selected_fusion == self.FUSION.ADD_SUM: + return add_sum(inputs) + elif self.selected_fusion == self.FUSION.TANH_SUM: + return tanh_sum(inputs) + elif self.selected_fusion == self.FUSION.EXP_SUM: + return exp_sum(inputs) + else: + assert False + + # Apply scheduler with custom parameters using decorator + def custom_scheduler(self, fd, scheduler_config): + def inner_fn(): + # Check if compatible with reduction scheduler + status, _ = fd.sched.can_schedule(SchedulerType.reduction) + assert status + + reduction_params = fd.sched.compute_reduction_heuristics() + + # Modify original parameters + if scheduler_config is not None: + self.convert_to_inner_reduction_params( + scheduler_config, reduction_params + ) + + # Schedule fusion + fd.sched.schedule() + + fd.schedule = inner_fn + return fd + + +# Run sequence of steps to collect data, train and test model +def main(): + # ====================== Setup Script Configuration ======================= + script_config = ScriptConfiguration( + num_dimensions=2, + outer_shapes=[16384], + inner_shapes=[128, 1024, 4096, 16384], + tensor_datatype=torch.bfloat16, + test_data_percentage=0.1, + empirical_batch_size=16384, + empirical_hidden_sizes=list(range(256, 32768, 256)), + ) + + autotune_config = AutotuneInnerReduction( + selected_fusion=AutotuneInnerReduction.FUSION.INNER_SUM + ) + + # ============================ Run Experiments ============================ + + parameters, performance = collect_data(script_config, autotune_config) + + # ============================ Separate Data ============================== + + train_data, test_data = separate_data(script_config, parameters, performance) + + # ========================= Train Regression Models ======================= + + # Apply machine learning regressor + # Given input shapes and scheduler parameters, predict performance metric. + from sklearn import ensemble + + train_inputs, train_perf = train_data + clf = ensemble.RandomForestRegressor() + clf = clf.fit(train_inputs, train_perf) + + # ========================= Test Regression Models ======================== + test_model_rmse(clf, script_config, autotune_config, test_data) + test_model(clf, script_config, autotune_config) + + +if __name__ == "__main__": + main() diff --git a/doc/dev/python_scheduling/autotune_utils.py b/doc/dev/python_scheduling/autotune_utils.py index 4017c6c87f8..e699f84270e 100644 --- a/doc/dev/python_scheduling/autotune_utils.py +++ b/doc/dev/python_scheduling/autotune_utils.py @@ -4,6 +4,7 @@ # Owner(s): ["module: nvfuser"] import torch +import math import itertools from nvfuser import FusionCache, FusionDefinition from dataclasses import dataclass, astuple @@ -13,6 +14,18 @@ # ============================================================================= +# Returns the result of a/b rounded to the nearest integer in the direction of +# positive infinity. +def ceil_div(a, b): + return int(math.ceil(a / b)) + + +# Returns the result of a/b rounded to the nearest integer in the direction of +# negative infinity. +def floor_div(a, b): + return int(math.floor(a / b)) + + @dataclass class ScriptConfiguration: # Settings for input tensor generation From 4382f28c78d3c2f16e17328ce7290b5b07f36637 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 11 Dec 2024 15:14:04 -0500 Subject: [PATCH 3/7] Add MatmulParams::cluster_dims parameter (#3574) Following #3557 we can specify the cluster size for our fusions. Currently we don't do anything explicitly with CGAs, but this can help guarantee that tiles are scheduled onto GPCs in pairs. Each GPC has a number of TPCs, each of which holds 2 SMs, so this lets us take advantage of caching at the TPC and GPC level for operand loads, in addition to L2. This PR enables this with a default size of `{2, 1, 1}` for the Hopper scheduler. The parameter is ignored in the Ampere scheduler. It is not yet plumbed into the heuristic plugin API yet. I thought maybe we should wait until we have more parameters related to CGAs to do that. --- csrc/scheduler/hopper_multi_matmul.cpp | 2 ++ csrc/scheduler/hopper_multi_matmul.h | 8 ++++++++ csrc/scheduler/matmul_heuristic.h | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index fbb95d46df2..a5b6f4d2bb7 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -53,6 +53,8 @@ void HopperMultipleMatmulScheduler::run() { inspectPrologues(); + setCGADims(); + scheduleOperands(); // schedule mma instruction output (mma_result) diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index bf7bc1df0f5..1d77785cc99 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -149,6 +149,14 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { std::vector> blockTileTensors( const std::vector& tvs); + //! Specifies the CGA dimensions by setting "cluster_dims" as fusion-managed + //! data + void setCGADims() const { + if (params_->cluster_dims != std::tuple{1, 1, 1}) { + fusion_->manage("cluster_dims", params_->cluster_dims); + } + } + //! Schedule the loads of all operands from global memory to shared memory. //! Starting from the basic tiled schedule, we swizzle the operand memory. //! Note that the cache op and LoadStoreOpType are already set during diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index 6a92d31fd2c..f66cd12e618 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -179,6 +179,10 @@ class MatmulParams : public HeuristicParams { //! axis and perform a grid reduction before the epilogue. int splitk_factor = 1; + //! This is the CGA size on Hopper+ devices. This parameter is ignored on + //! Ampere and Turing. + std::tuple cluster_dims = {2, 1, 1}; + std::string toString() const override { std::stringstream ss; ss << "\n===== Matmul Parameters ========\n" From d5af72fc90ffbf218d979ed8dff81be15750a52c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:27:53 -0500 Subject: [PATCH 4/7] Enable compilation in Hopper MMA test without input broadcasts (#3406) Stacked on #3410, #3414, and #3416 This simply enables compilation of the test which uses #3391. --- tests/cpp/test_matmul.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 88cc953c95d..cbd51d97cfb 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3819,9 +3819,9 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { Fusion fusion; FusionGuard fg(&fusion); - // constexpr int64_t M = 2048, N = 2048, K = 8192; + constexpr int64_t M = 2048, N = 2048, K = 8192; constexpr auto macro = MmaMacro::Hopper_64_256_16; - // constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N] + constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N] constexpr auto swizzle = MmaInputSmemSwizzle::B128; const auto dtype = DataType::Half; @@ -3954,7 +3954,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { // of 3 ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); inlineMost(); - tmp_fusion.printMath(); ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1); // The outermost loop dim of tv1c is a broadcast Mo axis, so @@ -3981,7 +3980,17 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { pred_checker.handle(kernel->topLevelExprs()); ASSERT_TRUE(pred_checker.found_mma); - // TODO: compile and run kernel once inlining is fixed + auto [A3d, B3d] = + matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype)); + at::Tensor A = A3d.squeeze(); + at::Tensor B = B3d.squeeze(); + std::vector inputs{A, B}; + + KernelExecutor ke; + ke.compile(&fusion, inputs, LaunchParams(), matmul_cparams); + auto cg_outputs = ke.run(inputs); + auto tref = atMatmul(A, B, layout); + EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } } // namespace nvfuser From 5716c09ee6759a5bbe1fd592121198eee9475450 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:29:13 -0500 Subject: [PATCH 5/7] Add MatmulParams::CircularBufferOptions::smem_circular_buffer_prefetch_gap (#3558) Previously we had hardcoded to use `smem_circular_buffer_stage - 1` but this gives us more flexibility in our heuristic. We will use `smem_circular_buffer_stage - smem_circular_buffer_prefetch_gap` as the prefetch distance. Parametrizing this way lets us modify the number of stages without needing to adjust the prefetch gap each time, since it will most commonly just be 1. --- csrc/scheduler/ampere_multi_matmul.cpp | 10 ++++++++-- csrc/scheduler/hopper_multi_matmul.cpp | 18 ++++++++++++++++-- csrc/scheduler/matmul_heuristic.h | 16 ++++++++++++++-- csrc/scheduler/matmul_heuristic_plugin.cpp | 4 ++++ csrc/scheduler/matmul_heuristic_plugin_api.h | 3 +++ tests/cpp/test_matmul_scheduler.cpp | 3 --- 6 files changed, 45 insertions(+), 9 deletions(-) diff --git a/csrc/scheduler/ampere_multi_matmul.cpp b/csrc/scheduler/ampere_multi_matmul.cpp index ee21e41ce8b..d582e9e9a10 100644 --- a/csrc/scheduler/ampere_multi_matmul.cpp +++ b/csrc/scheduler/ampere_multi_matmul.cpp @@ -1302,11 +1302,17 @@ void AmpereMultipleMatmulScheduler::setUpCircularBuffering() { for (TensorView* acw_smem : acw_smems_) { acw_smem->circularBuffer( - params_->circular_buffer_options.smem_circular_buffer_stage); + params_->circular_buffer_options.smem_circular_buffer_stage, + params_->circular_buffer_options.smem_circular_buffer_stage - + params_->circular_buffer_options + .smem_circular_buffer_prefetch_gap); } for (TensorView* bcw_smem : bcw_smems_) { bcw_smem->circularBuffer( - params_->circular_buffer_options.smem_circular_buffer_stage); + params_->circular_buffer_options.smem_circular_buffer_stage, + params_->circular_buffer_options.smem_circular_buffer_stage - + params_->circular_buffer_options + .smem_circular_buffer_prefetch_gap); } } diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index a5b6f4d2bb7..b2d8ec705ec 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -712,18 +712,32 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { params_->async_gmem_load_operands, "Circular buffer only supports async load"); } + NVF_CHECK( + params_->circular_buffer_options.smem_circular_buffer_prefetch_gap > + 0 && + params_->circular_buffer_options + .smem_circular_buffer_prefetch_gap <= + params_->circular_buffer_options.smem_circular_buffer_stage, + "smem_circular_buffer_prefetch_gap is ", + params_->circular_buffer_options.smem_circular_buffer_prefetch_gap, + " but is expected to be positive and not greater than number of stages: ", + params_->circular_buffer_options.smem_circular_buffer_stage); for (TensorView* acw_smem : acw_smems_) { acw_smem->circularBuffer( params_->circular_buffer_options.smem_circular_buffer_stage, /*prefetch_distance=*/ - params_->circular_buffer_options.smem_circular_buffer_stage - 1); + params_->circular_buffer_options.smem_circular_buffer_stage - + params_->circular_buffer_options + .smem_circular_buffer_prefetch_gap); } for (TensorView* bcw_smem : bcw_smems_) { bcw_smem->circularBuffer( params_->circular_buffer_options.smem_circular_buffer_stage, /*prefetch_distance=*/ - params_->circular_buffer_options.smem_circular_buffer_stage - 1); + params_->circular_buffer_options.smem_circular_buffer_stage - + params_->circular_buffer_options + .smem_circular_buffer_prefetch_gap); } } diff --git a/csrc/scheduler/matmul_heuristic.h b/csrc/scheduler/matmul_heuristic.h index f66cd12e618..7e8ee6dc4d7 100644 --- a/csrc/scheduler/matmul_heuristic.h +++ b/csrc/scheduler/matmul_heuristic.h @@ -41,10 +41,18 @@ class MatmulParams : public HeuristicParams { // greater than one. Otherwise it is ignored. int smem_circular_buffer_stage = 2; + // The circular buffering prefetch distance will be set to + // smem_circular_buffer_stage - smem_circular_buffer_prefetch_gap + // This value must be positive since the prefetch distance must be strictly + // less than the number of stages. + int smem_circular_buffer_prefetch_gap = 1; + bool operator==(const CircularBufferOptions& other) const { return other.circular_buffer_smem_write == circular_buffer_smem_write && other.circular_buffer_smem_read == circular_buffer_smem_read && - other.smem_circular_buffer_stage == smem_circular_buffer_stage; + other.smem_circular_buffer_stage == smem_circular_buffer_stage && + other.smem_circular_buffer_prefetch_gap == + smem_circular_buffer_prefetch_gap; } std::string toString() const { @@ -54,12 +62,16 @@ class MatmulParams : public HeuristicParams { << (circular_buffer_smem_write ? "true" : "false") << "\n" << " circular_buffer_smem_read: " << (circular_buffer_smem_read ? "true" : "false") << "\n" - << " smem_circular_buffer_stage: " << smem_circular_buffer_stage; + << " smem_circular_buffer_stage: " << smem_circular_buffer_stage + << "\n" + << " smem_circular_buffer_prefetch_gap: " + << smem_circular_buffer_prefetch_gap; return ss.str(); } size_t hash() const { return std::hash{}( + (static_cast(smem_circular_buffer_prefetch_gap) << 3) | (static_cast(smem_circular_buffer_stage) << 2) | (static_cast(circular_buffer_smem_write)) << 1) | (static_cast(circular_buffer_smem_read)); diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index b3821787b67..01333727841 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -135,6 +135,8 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { }; config->load_stages = mparams->circular_buffer_options.smem_circular_buffer_stage; + config->prefetch_gap = + mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap; config->async_gmem_load_operands = mparams->async_gmem_load_operands; setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile); setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile); @@ -163,6 +165,8 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile); mparams->circular_buffer_options.smem_circular_buffer_stage = config->load_stages; + mparams->circular_buffer_options.smem_circular_buffer_prefetch_gap = + config->prefetch_gap; mparams->async_gmem_load_operands = config->async_gmem_load_operands; // Update mma macro if necessary to match provided instruction tile MmaMacroEncode menc(mparams->mma_macro); // this will record the family diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h index 224705530e5..207da96e9a8 100644 --- a/csrc/scheduler/matmul_heuristic_plugin_api.h +++ b/csrc/scheduler/matmul_heuristic_plugin_api.h @@ -74,6 +74,9 @@ struct KernelConfig { Tile instruction_tile = {16, 16, 16}; uint16_t splitk_factor = 1; uint8_t load_stages = 2; + // The circular buffering prefetch distance will be set to + // load_stages - prefetch_gap + uint8_t prefetch_gap = 1; uint8_t grid_swizzle_factor = 0; uint8_t cta_order = 0; bool circular_buffer_smem_read = true; diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index fa80a096dce..3058ce59ad7 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -3197,9 +3197,6 @@ class HopperMatmulSchedulerTest mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = true; mparams.circular_buffer_options.smem_circular_buffer_stage = 4; - - // TODO Create prefetch parameter - // mparams.circular_buffer_options.smem_circular_buffer_prefetch = 3; } void TearDown() { From 6e1919f72c8c4a3b7574cfb107e9d0b3c1ca224b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 12 Dec 2024 14:30:09 -0800 Subject: [PATCH 6/7] Fix ComputeAtMap for non-linear ID dependencies (#3577) Just patching ComputeAtMap to exclude dead expressions and vals. --- csrc/compute_at_map.cpp | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/csrc/compute_at_map.cpp b/csrc/compute_at_map.cpp index 18184c8149a..61c8263cddc 100644 --- a/csrc/compute_at_map.cpp +++ b/csrc/compute_at_map.cpp @@ -610,7 +610,9 @@ void IterDomainGraph::build(Fusion* fusion) { // Grab all the logical ids. for (auto consumer_tv : all_consumer_tvs) { - auto exprs = StmtSort::getExprsTo( + auto exprs = StmtSort::getExprsBetween( + {consumer_tv->getMaybeRootDomain().begin(), + consumer_tv->getMaybeRootDomain().end()}, {consumer_tv->getLogicalDomain().begin(), consumer_tv->getLogicalDomain().end()}); for (auto expr : exprs) { @@ -663,6 +665,20 @@ void IterDomainGraph::build(Fusion* fusion) { continue; } + // logical_id_uses are guaranteed to be a valid expr, but + // first_logical_id->definition() may not be part of the valid + // exprs + if (!prop_forward) { + if (std::any_of( + first_expr->inputs().begin(), + first_expr->inputs().end(), + [&](Val* id_input) { + return !all_ids_.has(id_input->as()); + })) { + continue; + } + } + if (visited_exprs.find(first_expr) != visited_exprs.end()) { continue; } @@ -1282,6 +1298,13 @@ void ComputeAtMap::buildUniqueExactExprMaps() { if (id->definition() != nullptr) { auto id_inputs = ir_utils::filterByType(id->definition()->inputs()); + // If any input ID is not included in the map, this definition + // should not be included either. + if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) { + return !idExistsInMap(id_input); + })) { + continue; + } if (std::any_of(id_inputs.begin(), id_inputs.end(), [&](auto id_input) { return disjoint_set_shared_ptr->has(id_input); })) { From 201a636f88bac357b7a4f15d4ed62feb0ee62163 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 12 Dec 2024 15:20:17 -0800 Subject: [PATCH 7/7] Implement basic split-k gemm for hopper matmul scheduler (#3575) This PR implements `scheduleSplitKSum` function to support split-k gemm with the hopper matmul schedule. - It support all operand formats such as TT, NT, TN, NN. --- csrc/device_lower/utils.cpp | 6 +- csrc/scheduler/hopper_multi_matmul.cpp | 89 +++++++++----------------- csrc/scheduler/hopper_multi_matmul.h | 5 ++ csrc/tensor_view.cpp | 13 +++- csrc/transform_rfactor.cpp | 5 +- tests/cpp/test_matmul_scheduler.cpp | 43 ++++++++++--- 6 files changed, 87 insertions(+), 74 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index a3d4323e761..35b825d5348 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -921,7 +921,11 @@ std::array getMmaLayout(const MmaOp* expr) { auto out_tv = ir_utils::getTv(expr->out()); IterDomain* reduction_id = nullptr; - for (auto id : out_tv->getLogicalDomain()) { + // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K] + // using commitLeafToLogical. In the split-k case, use the root domain for the + // mma layout because the k dimension is divided into two iterDomains in the + // logical domain. + for (auto id : out_tv->getMaybeRootDomain()) { if (id->isReduction()) { reduction_id = id; break; diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index b2d8ec705ec..1efe75aeab2 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -29,6 +29,27 @@ namespace nvfuser { +void HopperMultipleMatmulScheduler::transformLikeMmaOutput( + TensorView* tv, + bool is_mma_result) { + // TODO Add constraints + + auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr { + return (is_mma_result) ? idx - 1 : idx; + }; + + // Original: [..., Mo, No, Mi, Ni] + tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro)); + tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro)); + // After Split: [..., Mo, No, Mio, Mii, Nio, Nii] + tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}}); + // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii] + tv->merge(apply_k_dim_offset(-4)); + // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii] + tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii] +} + MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { ValGroup vg = graph_->toGroup(id); auto it = id_roles_.find(vg); @@ -397,22 +418,13 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() { // do split-K rFactor to define splitk_sum and smem_epilogue if (params_->splitk_factor != 1) { - // TODO: schedule split-K - NVF_THROW("Hopper split-K is not yet tested"); // Note that the split-K split is already done in blockTileTensors TensorView* splitk_sum = mma_result->rFactor({-4, -1}); std::swap(splitk_sum, mma_result); splitk_sums_.push_back(splitk_sum); } - mma_result->split(-3, getM(params_->mma_macro)); - mma_result->split(-2, getN(params_->mma_macro)); - // [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki] - // -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki] - mma_result->reorder({{-4, -3}}); - mma_result->merge(-5); - mma_result->axis(-4)->parallelize(ParallelType::TIDy); - + transformLikeMmaOutput(mma_result, /*is_mma_result=*/true); auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( mma_result->getLoopDomain()); mma_result->setAllocationDomain(s.as(), true); @@ -509,17 +521,10 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { // Apply mma common transformation for (auto tv : {dc, d}) { - // [..., Mo, No, Mi, Ni] - tv->split(-2, getM(params_->mma_macro)); - tv->split(-1, getN(params_->mma_macro)); - // [..., Mo, No, Mio, Mii, Nio, Nii] - // -> [..., Mo, No, Mio, Nio, Mii, Nii] - tv->reorder({{-3, -2}}); - tv->merge(-4); + transformLikeMmaOutput(tv, /*is_mma_result=*/false); auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( tv->getLoopDomain()); tv->setLoopDomain(s.as()); - tv->axis(-5)->parallelize(ParallelType::TIDy); } d->axis(-1)->parallelize(ParallelType::Vectorize); } @@ -565,16 +570,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { // Apply mma common transformation for (auto tv : {dc, d_smem, d}) { - // Original: [..., Mo, No, Mi, Ni] - tv->split(-2, getM(params_->mma_macro)); - tv->split(-1, getN(params_->mma_macro)); - // After Split: [..., Mo, No, Mio, Mii, Nio, Nii] - tv->reorder({{-3, -2}}); - // After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii] - tv->merge(-4); - // After Merge: [..., Mo, No, Mio * Nio, Mii, Nii] - tv->axis(-3)->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii] + transformLikeMmaOutput(tv, /*is_mma_result=*/false); } // Schedule register cache; Output from epilogue @@ -643,41 +639,14 @@ void HopperMultipleMatmulScheduler::scheduleSplitKSum() { if (params_->splitk_factor == 1) { return; } - NVF_THROW("Split-K scheduling is not yet implemented for Hopper matmul"); for (TensorView* splitk_sum : splitk_sums_) { // Always use serial grid reduction for split-K sum splitk_sum->definition()->as()->requestSerialGridReduction(); - - if (params_->use_smem_epilogue) { - // Now that transforms are propagated backward to smem_epilogue, which - // is before splitk_sum, we can vectorize the inner-most non-trivial - // dimension of splitk_sum - // - // Note that the split-K reduction is the inner-most dimension. - Val* vec_ext = splitk_sum->axis(-2)->extent(); - NVF_ERROR(vec_ext->isConstInt()); - int64_t vec_ext_int = vec_ext->evaluate().as(); - splitk_sum->axis(-1)->parallelize(ParallelType::BIDz); - splitk_sum->axis(-3)->parallelize(ParallelType::TIDx); - if (vec_ext_int * dataTypeSize(splitk_sum->dtype()) > 16) { - // NOTE: We might encounter an illegal vectorization size if we are - // using Float for this reduction and Half for output. So here we - // first check whether the vectorize size is at most 16 bytes. If not, - // then we split into an unrolled loop that will do multiple - // vectorized reads/writes instead. Note that we reorder such that the - // axes are in order UR TIDx V. - splitk_sum->split( - -2, 16 / dataTypeSize(splitk_sum->dtype()), /*inner_split=*/true); - splitk_sum->axis(-3)->parallelize(ParallelType::Unroll); - splitk_sum->reorder({{-4, -3}}); - // In this case, we have [... iUR iTx rBz iS] - } - splitk_sum->reorder({{-2, -1}}); - } else { // no smem epilogue - // Reorder to place the split-K reduction next to innermost [... rBz iS] - splitk_sum->reorder({{-9, -2}}); - } - // Vectorize inner-most dimension [... (iUR iTx) rBz iV] + transformLikeMmaOutput(splitk_sum, /*is_mma_result=*/false); + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + splitk_sum->getLoopDomain()); + splitk_sum->setLoopDomain(s.as()); + splitk_sum->axis(2)->parallelize(ParallelType::BIDz); splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize); } } diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 1d77785cc99..5eab0f4fbed 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -191,6 +191,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { // Return MatmulDimRole for IterDomain MatmulDimRole findMatmulDimRole(IterDomain* id); + // Schedule a block-tiled TensorView like mma output. + // Why? WGMMA has a unique output format. TensorViews after the mma-result in + // registers must respect this format for correctness. + void transformLikeMmaOutput(TensorView* tv, bool is_mma_result); + private: std::vector canonical_dim_ordering_; diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp index 654ff601ac7..14e0b8f746f 100644 --- a/csrc/tensor_view.cpp +++ b/csrc/tensor_view.cpp @@ -804,9 +804,12 @@ TensorView* TensorView::rFactor(const std::vector& axes) { "Error rfactoring ", this, " its definition is either a nullptr or not a reduction."); + // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K] + // using commitLeafToLogical. Thus, the original logical domain is moved to + // the root domain. NVF_CHECK( - !domain()->hasRoot(), "Cannot call rfactor on the same view twice."); - + definition()->isA() || !domain()->hasRoot(), + "Cannot call rfactor on the same view twice."); NVF_CHECK( !definition()->isA(), "For GroupedReductionOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); @@ -935,8 +938,12 @@ std::vector TensorView::rFactor( this, " its definition is either a nullptr or not a GroupedReductionOp or a multi-output reduction op."); + // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K] + // using commitLeafToLogical. Thus, the original logical domain is moved to + // the root domain. NVF_CHECK( - !domain()->hasRoot(), "Cannot call rfactor on the same view twice."); + definition()->isA() || !domain()->hasRoot(), + "Cannot call rfactor on the same view twice."); NVF_CHECK( definition()->outputs().size() == tvs.size(), diff --git a/csrc/transform_rfactor.cpp b/csrc/transform_rfactor.cpp index 07799487eb0..709c5624935 100644 --- a/csrc/transform_rfactor.cpp +++ b/csrc/transform_rfactor.cpp @@ -340,7 +340,10 @@ std::pair TransformRFactor::runReplay( [](IterDomain* id) { return id->maybePartial(); }), "rFactor of partial domains not allowed, but at least one found."); - auto original_td_root = original_td->logical(); + // For hopper matmuls, the mma_result logical domain is reordered as [M, N, K] + // using commitLeafToLogical. Thus, the original logical domain is moved to + // the root domain. In this case, map from producer to consumer's root domain. + auto original_td_root = original_td->maybeRoot(); // Generate a new TensorDomain and set up map from one root to this one. std::vector new_producer_root(original_td_root.size(), nullptr); diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 3058ce59ad7..838f96cc140 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -3120,7 +3120,9 @@ using HopperMatmulSchedulerTestParams = std::tuple< int64_t, // M int64_t, // N int64_t, // K - MmaMacro>; + MmaMacro, + int64_t // SplitK Factor + >; std::string hopperTestName( const testing::TestParamInfo& info) { @@ -3129,8 +3131,16 @@ std::string hopperTestName( bool a_k_inner, b_k_inner; int64_t M, N, K; MmaMacro mma_macro; - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = - info.param; + int64_t splitk_factor; + std::tie( + use_smem_epilogue, + a_k_inner, + b_k_inner, + M, + N, + K, + mma_macro, + splitk_factor) = info.param; os << (a_k_inner ? "K" : "M"); os << (b_k_inner ? "K" : "N"); os << "_" << M << "_" << N << "_" << K; @@ -3138,6 +3148,9 @@ std::string hopperTestName( if (use_smem_epilogue) { os << "_tma_store"; } + if (splitk_factor > 1) { + os << "_splitk_" << splitk_factor; + } return os.str(); } @@ -3162,8 +3175,15 @@ class HopperMatmulSchedulerTest void SetUp() { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0); - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = - GetParam(); + std::tie( + use_smem_epilogue, + a_k_inner, + b_k_inner, + M, + N, + K, + mma_macro, + splitk_factor) = GetParam(); if (a_k_inner) { layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT; @@ -3192,11 +3212,12 @@ class HopperMatmulSchedulerTest mparams.use_smem_epilogue = use_smem_epilogue; + mparams.splitk_factor = splitk_factor; mparams.tile_sizes = gemm_tile; mparams.async_gmem_load_operands = true; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = true; - mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_stage = 2; } void TearDown() { @@ -3215,7 +3236,8 @@ class HopperMatmulSchedulerTest KernelExecutor ke; ke.compile(fusion, inputs, LaunchParams(), matmul_cparams); auto nvf_out = ke.run(inputs); - EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-5, 1e-5)); + // NOTE Relax tolerances for split-k case + EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3)); } protected: @@ -3223,6 +3245,7 @@ class HopperMatmulSchedulerTest bool a_k_inner, b_k_inner; int64_t M, N, K; MmaMacro mma_macro; + int64_t splitk_factor; std::unique_ptr fusion_up; Fusion* fusion; std::unique_ptr fusion_guard; @@ -3304,7 +3327,8 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(512), // M testing::Values(256), // N testing::Values(64), // K - testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros + testing::Values(MmaMacro::Hopper_64_128_16), // mma_macros + testing::Values(1, 2) // SplitK Factor ), hopperTestName); @@ -3323,7 +3347,8 @@ INSTANTIATE_TEST_SUITE_P( MmaMacro::Hopper_64_128_16, MmaMacro::Hopper_64_64_16, MmaMacro::Hopper_64_32_16, - MmaMacro::Hopper_64_16_16) // mma_macros + MmaMacro::Hopper_64_16_16), // mma_macros + testing::Values(1) // SplitK Factor ), hopperTestNameSwizzle);