From c14d4181ae499c90b5728c06dfac248a5afab8b0 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 29 Oct 2024 22:07:09 -0700 Subject: [PATCH] Profile configurations for InnerOuterPersistent scheduler in python frontend (#3118) # Summary This PR explores auto-tuning a `LayerNormBackward` fusion using the `InnerOuterPersistent` scheduler in the python-frontend. - Create `autotune_persistent.py` to test several parameter configurations then apply `DecisionTreeRegressor` - The selected performance metric is `effective_bandwidth_gbs`. The empirical scheduler selects the configuration that has the highest predicted `effective_bandwidth_gbs`. # Key differences from approach for `Pointwise` scheduler - `vectorize_factor`, `thread_per_block_min`, and `thread_per_block_max` are specified before running `computeHeuristics`. These settings are akin to hyper-parameters used to constrain the generated scheduler parameters. - Create `SchedulerHyperParameters` as an entry in `HeuristicDataCache` to specify these constraints when generating scheduler parameters. # Details 1. Create `struct SchedulerHyperParameters` in `csrc/scheduler/utils.h` 2. Create `HeuristicDataCacheEntry` in `csrc/scheduler/compile_time_info.h` 3. Modify `computeHeuristics` to use hyper-parameter constraints. 4. Expose `SchedulerHyperParameters` in python frontend. 5. Allow user schedulers to define a `HeuristicDataCache` during scheduling. * `ScheduleHyperParameters` contains parameters for `vectorize_factor`, `unroll_factor`, `threads_per_block_min`, and `threads_per_block_max`. --- csrc/python_frontend/fusion_cache.cpp | 3 +- csrc/python_frontend/fusion_cache.h | 4 + csrc/python_frontend/fusion_definition.cpp | 4 + csrc/python_frontend/python_bindings.cpp | 63 ++- csrc/scheduler/compile_time_info.h | 12 +- csrc/scheduler/normalization_inner_outer.cpp | 75 +++- csrc/scheduler/registry.cpp | 2 + csrc/scheduler/utils.h | 28 ++ .../python_scheduling/autotune_persistent.py | 417 ++++++++++++++++++ 9 files changed, 590 insertions(+), 18 deletions(-) create mode 100644 doc/dev/python_scheduling/autotune_persistent.py diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 53dc43bdbe8..83ce851dbab 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -227,7 +227,8 @@ HeuristicParams* UserSchedule::computeHeuristics(SchedulerType scheduler_type) { NVF_CHECK( heuristic_params == nullptr, "Heuristic Scheduler is already defined for this UserSchedule"); - heuristic_params = scheduler->computeHeuristics(fusion(), runtime_info_ref); + heuristic_params = scheduler->computeHeuristics( + fusion(), runtime_info_ref, data_cache.get()); return heuristic_params.get(); } diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index b4283b7bdaf..190671b2b82 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -33,6 +34,9 @@ struct UserSchedule { //! The parameters for scheduler heuristic. std::unique_ptr heuristic_params; + //! The compile-time data cache. + std::unique_ptr data_cache; + //! 
Concretized, Scheduled Fusion IR std::unique_ptr scheduled_fusion; diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index b512d9d761b..09648a0bf36 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -239,6 +240,9 @@ void FusionDefinition::setupSchedule( user_sched_ = fusionCache()->createUserSchedule( scheds, inputs, device, overwrite_existing_schedule); + // Create scheduler data cache + user_sched_->data_cache = std::make_unique(); + // Building a new Fusion container for scheduling with definition such that // the definition's tensor data members refer to the corresponding IR objects // needed for scheduling. A simple copy of the container would mean the data diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index d7b9e8d1c34..b229107c45b 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -23,9 +23,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -779,6 +781,44 @@ void initNvFuserPythonBindings(PyObject* module) { defineHeuristicParamBindings(nvfuser); + py::class_ hyperparameters( + nvfuser, "SchedulerHyperParameters"); + hyperparameters.def(py::init()); + hyperparameters.def_property( + "vectorize_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.vectorize_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t vectorize_factor_) { + self.vectorize_factor = vectorize_factor_; + }); + hyperparameters.def_property( + "unroll_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.unroll_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t unroll_factor_) { self.unroll_factor = unroll_factor_; }); + hyperparameters.def_property( + "threads_per_block_min", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_min; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_min_) { + self.threads_per_block_min = threads_per_block_min_; + }); + hyperparameters.def_property( + "threads_per_block_max", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_max; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_max_) { + self.threads_per_block_max = threads_per_block_max_; + }); + //! KernelProfiles are encapsulated in FusionProfiles where each KP //! is associated with a segment. py::class_ kernel_prof(nvfuser, "KernelProfile"); @@ -1401,7 +1441,7 @@ void initNvFuserPythonBindings(PyObject* module) { py::class_ nvf_ops(fusion_def, "Operators"); nvf_ops.def(py::init()); - // ******************** INSERT OP BINDINGS BELOW HERE ******************** +// ******************** INSERT OP BINDINGS BELOW HERE ******************** #define OP_PREFIX "Operators." 
#define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ nvf_ops.def( \ @@ -3822,6 +3862,27 @@ void initNvFuserPythonBindings(PyObject* module) { return *parameters->as(); }, py::return_value_policy::reference); + nvf_sched.def( + "schedule_hyperparameters", + [](FusionDefinition::SchedOperators& self) + -> scheduler_utils::SchedulerHyperParameters& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>( + sched->data_cache.get(), []() { + return std::make_unique< + scheduler_utils::SchedulerHyperParameters>( + /*vectorize_factor=*/1, + /*unroll_factor=*/1, + /*threads_per_block_min=*/1, + /*threads_per_block_max=*/1); + }); + return scheduler_hyperparameters_entry.get(); + }, + py::return_value_policy::reference); } void cleanup() { diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index 18b5efb0e8e..d413c99ae81 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -46,7 +46,8 @@ enum class CompileTimeEntryType { CAN_SCHEDULE_TRANSPOSE, CAN_SCHEDULE_MUL_SUM_AS_MMA, LOGICAL_REORDER_MAP, - VECTORIZATION_BREAK_POINT_OF_RED_PROD + VECTORIZATION_BREAK_POINT_OF_RED_PROD, + SCHEDULE_HYPERPARAMETERS }; //! Entry type definition class for `DOMAIN_MAP`, @@ -203,6 +204,15 @@ class VectorizationBreakPointOfReductionProducer { CompileTimeEntryType::VECTORIZATION_BREAK_POINT_OF_RED_PROD; }; +//! Entry type definition class for `SCHEDULE_HYPERPARAMETERS`, +//! stores hyperparameters for SchedulerEntry::computeHeuristics +class SchedulerHyperParameters { + public: + using DataType = scheduler_utils::SchedulerHyperParameters; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::SCHEDULE_HYPERPARAMETERS; +}; + //! Base abstract class for unified storage in `HeuristicDataCache`, //! each entry in `HeuristicDataCache` will be a subclass. 
class CompileTimeInfoBase : public PolymorphicBase { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 6dd34f4cab9..2ea854f0a88 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -186,7 +186,9 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( SchedulerRuntimeInfo& runtime_info, HeuristicDataCache* data_cache, const std::vector& reduction_tvs, - const int64_t vectorize_factor) { + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max) { FUSER_PERF_SCOPE( "normalization_inner_outer::getPersistentBufferStorageParams"); @@ -230,9 +232,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( const auto dev_prop = at::cuda::getCurrentDeviceProperties(); int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( - fusion, - reduction_tvs, - InnerOuterPersistentKernelScheduler::threads_per_block_max); + fusion, reduction_tvs, threads_per_block_max); int64_t available_smem = (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; int64_t available_regs = scheduler_utils::register_file_size_56k; @@ -281,8 +281,8 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( tv_buffer_size_regs, dataTypeSize(current_tv->getDataType().value()), vectorize_factor, - InnerOuterPersistentKernelScheduler::threads_per_block_min, - InnerOuterPersistentKernelScheduler::threads_per_block_max, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); buffer_params.smem_buffer_size += tv_buffer_size_smem; @@ -332,6 +332,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( const int64_t outer_dim_numel, const int64_t persistent_buffer_size, const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const int64_t warp_size) { // if inner_dim_numel <= 1024, we are doing multiple reductions per block // with a constant batch size of 1 if vectorized. See Step 5 of @@ -380,11 +382,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( }; const int64_t after_vectorization = inner_dim_numel / vectorize_factor; - const int64_t threads_per_block_min = std::min( - after_vectorization, - InnerOuterPersistentKernelScheduler::threads_per_block_min); - const int64_t threads_per_block_max = - InnerOuterPersistentKernelScheduler::threads_per_block_max; + const int64_t threads_per_block_min_after_vectorization = + std::min(after_vectorization, threads_per_block_min); const int64_t batch_min = getMinimumBatch(); const int64_t batch_max = getMaximumInnerOuterPersistentBufferBatch(); @@ -392,7 +391,7 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( // is larger than batch_max, try increase threads per block by a warp until // the threads_per_block reaches threads_per_block_max or the batch size // reaches batch_min. 
- int64_t threads_per_block = threads_per_block_min; + int64_t threads_per_block = threads_per_block_min_after_vectorization; int64_t inner_batch = ceilDiv(after_vectorization, threads_per_block); while (inner_batch > batch_max && threads_per_block + warp_size <= threads_per_block_max && @@ -432,6 +431,8 @@ std::unique_ptr innerOuterPersistentHeuristic( const int64_t smem_overhead, const size_t tmp_gmem_dtype_size, const size_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const bool project_to_input, const PrimDataType index_type) { auto rparams = std::make_unique( @@ -512,6 +513,8 @@ std::unique_ptr innerOuterPersistentHeuristic( outer_dim_numel, regs_buffer_size, iop.inner_vect, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); iop.inner_batch = persistent_batch; @@ -743,12 +746,32 @@ std::unique_ptr getInnerOuterPersistentHeuristics( scheduler_utils::persistentBuffers(fusion)); }); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + auto& persistent_buffer_info = persistent_buffer_info_entry.get(); NVF_ERROR( !persistent_buffer_info.persistent_buffers.empty(), "Persistent scheduler requires persistent buffers."); auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); std::unique_ptr rparams = innerOuterPersistentHeuristic( properties.total_iteration_numel, @@ -757,7 +780,9 @@ std::unique_ptr getInnerOuterPersistentHeuristics( buffer_params.smem_buffer_size, buffer_params.smem_overhead, max_outer_reduction_dtype_size, - vectorize_factor, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max, buffer_params.project_to_input, runtime_info.getIndexType()); @@ -1244,9 +1269,29 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( data_cache, (int)(reduced_tv->nDims() - properties.inner_most_dimension_ndims)); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + // check if there is enough register and shared memory for persistence const auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 32cb0aa30f0..3f6b5827db5 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp 
@@ -224,4 +224,6 @@ template class HeuristicDataCacheEntry< template class HeuristicDataCacheEntry; template class HeuristicDataCacheEntry< HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>; +template class HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>; } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 5dab953dead..77317cde31b 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -173,6 +173,34 @@ inline void parallelizeAllLike( propagate_padding); } +// Common hyperparameters used in heuristic scheduler. These hyperparameters +// are passed to SchedulerEntry::computeHeuristics through the +// HeuristicDataCache. These hyperparameters alter the generation of the +// HeuristicParams for the scheduler. +struct SchedulerHyperParameters { + SchedulerHyperParameters( + int64_t vectorize_factor_, + int64_t unroll_factor_, + int64_t threads_per_block_min_, + int64_t threads_per_block_max_) + : vectorize_factor(vectorize_factor_), + unroll_factor(unroll_factor_), + threads_per_block_min(threads_per_block_min_), + threads_per_block_max(threads_per_block_max_) {} + + //! Number of elements to load per vectorize load. + int64_t vectorize_factor = 1; + + //! Number of iterations to unroll for-loop. + int64_t unroll_factor = 1; + + //! Minimum number of threads per block. + int64_t threads_per_block_min = 1; + + //! Maximum number of threads per block. + int64_t threads_per_block_max = 1; +}; + struct PersistentBufferInfo { std::vector persistent_buffers; std::unordered_set unmappable_dims; diff --git a/doc/dev/python_scheduling/autotune_persistent.py b/doc/dev/python_scheduling/autotune_persistent.py new file mode 100644 index 00000000000..9cb02c6c0e7 --- /dev/null +++ b/doc/dev/python_scheduling/autotune_persistent.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +import itertools +import random +from nvfuser import FusionCache, FusionDefinition, SchedulerType, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from copy import deepcopy + +# ============================ Description ============================ + +# 1. Define a nvfuser fusion and its pytorch eager mode reference. +# +# 2. Profile the CUDA kernel performance by iterating over a set of input +# arguments and scheduler configurations. +# +# 3. Train a regression model to predict the desired performance metric given +# some input arguments and a scheduler configuration. +# +# 4. Measure the performance of the regression model. +# - Calculate RMSE of predicted and actual performance on test set. +# - Find the configuration with the best performance using regression model. +# Then, compare against the heuristic configuration selected by nvfuser. +# - For a specific batch size, gather performance across a range of hidden +# sizes. Calculate performance for best predicted and nvfuser +# configurations. Plot a chart comparing performance using matplotlib. + +# The selected performance metric is effective_bandwidth_gbs. The empirical +# scheduler selects the configuration that has the highest predicted +# effective_bandwidth_gbs. 
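+
+# A minimal sketch of the hyperparameter hook added by this PR: inside a
+# user-defined schedule, fd.sched.schedule_hyperparameters() returns the
+# SchedulerHyperParameters entry stored in the UserSchedule's
+# HeuristicDataCache, and any fields set on it constrain the scheduler
+# parameters later produced by computeHeuristics. This hypothetical helper is
+# not called anywhere in the script, and its values (vectorize factor 4,
+# 256 threads per block) are illustrative only; custom_persistent_scheduler
+# further down is the version actually used for profiling.
+def example_schedule_with_hyperparameters(fd):
+    # Check the fusion is compatible with the inner-outer persistent scheduler
+    status, _ = fd.sched.can_schedule(SchedulerType.inner_outer_persistent)
+    assert status
+
+    # Constrain the heuristic before scheduling
+    hyperparameters = fd.sched.schedule_hyperparameters()
+    hyperparameters.vectorize_factor = 4
+    hyperparameters.unroll_factor = 1
+    hyperparameters.threads_per_block_min = 256
+    hyperparameters.threads_per_block_max = 256
+
+    # Generate the schedule with the constrained heuristic
+    fd.sched.schedule(SchedulerType.inner_outer_persistent)
+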
+ +# ============================ Configurations ============================ + +# Settings for input tensor generation +num_dimensions = 2 +outer_shapes = [256, 1024, 4096, 16384] +inner_shapes = [2**i for i in range(10, 15)] + +# For pointwise scheduler, we test the cartesian product of vectorization and +# cta_size factors. +parameter_configurations = [ + vectorize_range := [1, 2, 4, 8], + threads_per_cta_range := list(range(128, 288, 32)), +] + +# We profile a range of input shapes with various configurations. +# This argument determines how much of the profiled data to keep as a test set. +test_data_percentage = 0.1 + +# The selected batch size for empirical and nvfuser comparison. +empirical_batch_size = 512 + +# The range of hidden sizes for empirical and nvfuser comparision. +empirical_hidden_sizes = list(range(1024, 28672, 256)) + +# NOTE For 24gb memory limit +# empirical_hidden_sizes = list(range(256, 22784, 256)) + + +def create_inputs(shape): + """Create input arguments for nvfuser fusion and eager mode""" + a = torch.randn(*shape, dtype=torch.bfloat16, device="cuda", requires_grad=True) + grads = torch.randn(*shape, dtype=torch.bfloat16, device="cuda") + weights = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + bias = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + + eps = 1e-5 + mean = a.to(torch.float).mean(dim=-1) + variance = a.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(1) + + nvf_inputs = [a, grads, mean, invstd, weights] + eager_inputs = [a, weights, bias, grads] + return nvf_inputs, eager_inputs + + +# A decorator to create a pointwise fusion given some input arguments. +def create_fusion_func(inputs): + PROMOTE_DTYPES = [DataType.BFloat16, DataType.Half] + dtype = torch_dtype_to_nvfuser_dtype(inputs[0].dtype) + + def fusion_func(fd: FusionDefinition): + T0 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + T1 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + + T2 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False + ) + T3 = fd.define_tensor( + shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False + ) + + T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + + if dtype in PROMOTE_DTYPES: + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.cast(T4, dtype=DataType.Float) + + V8 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) + T9 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + V12 = T0.shape() + T13 = fd.ops.broadcast_in_dim(T9, shape=V12, broadcast_dims=[0, 1]) + T14 = fd.ops.sub(T0, T13) + + T18 = fd.ops.broadcast_in_dim(T3, shape=V12, broadcast_dims=[0, 1]) + T19 = fd.ops.mul(T14, T18) + + T23 = fd.ops.broadcast_in_dim(T4, shape=V12, broadcast_dims=[1]) + T28 = fd.ops.sum(T1, dims=[0], keepdim=False, dtype=DataType.Null) + + T30 = fd.ops.mul(T1, T23) + T31 = fd.ops.mul(T1, T19) + T32 = fd.ops.sum(T31, dims=[0], keepdim=False, dtype=DataType.Null) + + T34 = fd.ops.mul(T30, T18) + T35 = fd.ops.mul(T30, T14) + T36 = fd.ops.sum(T35, dims=[1], keepdim=False, dtype=DataType.Null) + + T40 = fd.ops.broadcast_in_dim(T36, shape=V8, broadcast_dims=[0]) + T41 = fd.ops.neg(T34) + T42 = fd.ops.sum(T41, dims=[1], keepdim=False, dtype=DataType.Null) + T46 = fd.ops.broadcast_in_dim(T42, shape=V8, 
broadcast_dims=[0]) + S47 = fd.define_scalar(-0.500000, dtype=DataType.Double) + T48 = fd.ops.mul(S47, T40) + S49 = fd.define_scalar(3.00000, dtype=DataType.Double) + T50 = fd.ops.pow(T3, S49) + T51 = fd.ops.mul(T48, T50) + T54 = fd.ops.sum(T46, dims=[1], keepdim=False, dtype=DataType.Null) + T55 = fd.ops.sum(T51, dims=[1], keepdim=False, dtype=DataType.Null) + + T59 = fd.ops.broadcast_in_dim(T55, shape=V8, broadcast_dims=[0]) + T63 = fd.ops.broadcast_in_dim(T59, shape=V12, broadcast_dims=[0, 1]) + T67 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + T71 = fd.ops.broadcast_in_dim(T67, shape=V12, broadcast_dims=[0, 1]) + + S72 = fd.define_scalar(2.00000, dtype=DataType.Double) + T73 = fd.ops.mul(S72, T63) + T74 = fd.ops.sub(T0, T71) + T75 = fd.ops.mul(T73, T74) + + S77 = fd.ops.reciprocal(T0.size(1)) + T78 = fd.ops.mul(T75, S77) + T82 = fd.ops.broadcast_in_dim(T54, shape=V8, broadcast_dims=[0]) + T86 = fd.ops.broadcast_in_dim(T82, shape=V12, broadcast_dims=[0, 1]) + T88 = fd.ops.mul(S77, T86) + T89 = fd.ops.add(T78, T88) + T90 = fd.ops.add(T34, T89) + + if dtype in PROMOTE_DTYPES: + T28 = fd.ops.cast(T28, dtype=dtype) + T90 = fd.ops.cast(T90, dtype=dtype) + T32 = fd.ops.cast(T32, dtype=dtype) + + fd.add_output(T90) + fd.add_output(T32) + fd.add_output(T28) + + return fusion_func + + +# The pytorch eager mode reference used to validating nvfuser kernel. +def eager_reference(inputs): + inputs_cloned = deepcopy(inputs) + a, weights, bias, grad_output = inputs_cloned + eager_output = torch.nn.functional.layer_norm( + a.to(torch.double), + a.shape[1:], + weight=weights.to(torch.double), + bias=bias.to(torch.double), + ) + grad_output = grad_output.to(torch.double) + eager_output.backward(grad_output) + return [a.grad, weights.grad, bias.grad] + + +# ============================ Function Definitions ============================ + + +# Apply scheduler with custom parameters using decorator +def custom_persistent_scheduler(fd, config): + def inner_fn(): + # Check if compatible with persistent scheduler + status, _ = fd.sched.can_schedule(SchedulerType.inner_outer_persistent) + assert status + + # Modify original parameters + if config is not None: + hyperparameters = fd.sched.schedule_hyperparameters() + vectorize_factor, threads_per_block = config + hyperparameters.vectorize_factor = vectorize_factor + hyperparameters.threads_per_block_min = threads_per_block + hyperparameters.threads_per_block_max = threads_per_block + + # Schedule fusion + fd.sched.schedule(SchedulerType.inner_outer_persistent) + + fd.schedule = inner_fn + return fd + + +# Apply schedule decorator, run fusion, and profile performance +def run_profile(presched_fd, nvf_inputs, eager_inputs, config=None): + scheduled_fd = custom_persistent_scheduler(presched_fd, config) + nvf_outputs = scheduled_fd.execute(nvf_inputs, profile=True) + + # validate correctness + ref_outputs = eager_reference(eager_inputs) + for nvf_out, ref_out in zip(nvf_outputs, ref_outputs): + assert torch.allclose(nvf_out, ref_out, atol=1e-1, rtol=1e-1) + + prof = scheduled_fd.profile() + bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs + time = prof.kernel_profiles[0].time_ms + return bandwidth, time + + +def argmax(map_config_to_perf): + best_perf = -1 + best_config = None + for config, perf in map_config_to_perf.items(): + if perf > best_perf: + best_perf = perf + best_config = config + return best_config + + +# Given a prediction model, input_shape, and set of parameter configurations, +# find the best parameters +def 
find_best_parameters(predictor, input_shape, parameter_configurations): + map_config_to_performance = { + config: predictor.predict([[*input_shape, *config]]) + for config in itertools.product(*parameter_configurations) + } + return argmax(map_config_to_performance) + + +# ============================ Run Experiments ================================ + +# Collect data for decision tree +parameters = [] +performance = [] + +for shape in itertools.product(outer_shapes, inner_shapes): + print(shape) + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + # vectorization and threads_per_cta configurations + for config in itertools.product(*parameter_configurations): + perf_metric, _ = run_profile(presched_fd, nvf_inputs, eager_inputs, config) + parameters.append((*shape, *config)) + performance.append(perf_metric) + +# ============================ Separate Data ================================== + +# Separate collected data into training and test sets +train_data = [] +test_data = [] +train_perf = [] +test_perf = [] +test_shapes = set() +all_test_config = {} # key: input_shape, value: (config, perf) + +for data, perf in zip(parameters, performance): + shape = data[:num_dimensions] + config = data[num_dimensions:] + + if shape in all_test_config: + all_test_config[shape][config] = perf + else: + all_test_config[shape] = {config: perf} + + if random.random() < test_data_percentage: + test_data.append(data) + test_perf.append(perf) + else: + test_shapes.add(shape) + train_data.append(data) + train_perf.append(perf) + +# key: input_shape, value: best_config +best_test_config = {shape: argmax(all_test_config[shape]) for shape in test_shapes} + +# ========================= Train Regression Models =========================== + +# Apply decision tree regressor +# Given input shapes and scheduler parameters, predict performance metric. 
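+# Each training sample is the flat feature vector
+# (outer_size, inner_size, vectorize_factor, threads_per_block), and the
+# regression target is the effective_bandwidth_gbs measured for that
+# configuration above.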
+from sklearn import tree + +clf = tree.DecisionTreeRegressor() +clf = clf.fit(train_data, train_perf) +test_pred = clf.predict(test_data) + +print("===================== measure performance rmse ========================") + +# Estimate prediction error with RMSE +import numpy as np + +test_perf = np.array(test_perf) +print( + "Test prediction error (RMSE)", + np.sqrt(np.mean(np.power(test_perf - test_pred, 2))), +) +print("Test performance", test_perf) +print("Test prediction", test_pred) + +print("======================= compare configurations =======================") +# Find best configuration for test_shapes +print( + "input shape, estimate_config:(vectorization, cta_size), actual_config:(vectorization, cta_size), correct" +) +correctness_count = 0 +mismatch_configs = [] +for shape in test_shapes: + estimate_config = find_best_parameters(clf, shape, parameter_configurations) + + match_config = estimate_config == best_test_config[shape] + if not match_config: + mismatch_configs.append((shape, estimate_config)) + + correctness_count += int(match_config) + print(f"{shape}, {estimate_config}, {best_test_config[shape]}, {match_config}") +print("% of predictions match nvfuser parameters", correctness_count / len(test_shapes)) +print(correctness_count, "out of", len(test_shapes)) + +print("======================= compare performance =========================") + +for shape, estimate_config in mismatch_configs: + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_perf = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + _, nvf_perf = run_profile(presched_fd, nvf_inputs, eager_inputs) + est_perf_faster = est_perf < nvf_perf + print( + f"{shape} \t estimate_perf:{est_perf:.5f} \t nvfuser_perf:{nvf_perf:.5f} \t is_estimated_config_faster:\t{est_perf_faster}" + ) + +print("=====================================================================") + +# For a specific batch size, gather performance across a range of hidden sizes. +# Calculate performance for best predicted and nvfuser configurations. Plot a +# chart comparing performance using matplotlib. + +# NOTE: The matplotlib experiment plots the kernel runtime, which could be +# different than the selected performance metric. Currently, the performance +# metric is effective_bandwidth_gbs. + +import matplotlib.pyplot as plt +import numpy as np + +# Avoid reusing any cached, user-scheduled fusions to have a clean run. 
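+# FusionCache.reset() discards all previously compiled fusions and their
+# user-defined schedules, so the timing loops below compile from scratch.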
+FusionCache.reset() +est_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + est_perfs.append(est_time_ms) + print( + f"decision tree: {empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms:.3f}" + ) + +FusionCache.reset() +nvf_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, nvf_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs) + nvf_perfs.append(nvf_time_ms) + print(f"nvfuser: {empirical_batch_size}, {hidden_shape}, {nvf_time_ms:.3f}") + +# Get mean speed-up from nvfuser to empirical configurations across all input shapes. +# Negative value mean empirical configurations are slower than nvfuser. +print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs))) + +np_hidden_size = np.array(empirical_hidden_sizes) +plt.plot(np_hidden_size, np.array(est_perfs)) +plt.plot(np_hidden_size, np.array(nvf_perfs)) + +plt.xlabel("Hidden Size") +plt.ylabel("Time(ms)") +plt.title( + f"Batch Size = {empirical_batch_size}, Compare Decision Tree Heuristic vs NvFuser" +) +plt.legend(["decision_tree", "nvfuser"], loc="lower right") +plt.savefig(f"persistent_inner_outer_empirical_batchsize{empirical_batch_size}.png") + +# =============================================================================
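+
+# Usage note (assumptions: a CUDA-capable GPU plus torch, nvfuser,
+# scikit-learn, and matplotlib installed): run this script directly, e.g.
+#
+#   python doc/dev/python_scheduling/autotune_persistent.py
+#
+# The comparison chart is saved to the working directory as
+# persistent_inner_outer_empirical_batchsize<batch_size>.png.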