From e4c98249f4606201eddda7a7106d1cd372ea9aa0 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:37:41 -0400 Subject: [PATCH 01/15] fix padded bdimx to use warp reduction in inner reduction scheduler (#3288) Simple fix to padded bdimx in inner reduction scheduler. **Performance changes:** **(1) H100** 5 lines corresponds to batch size of 16, 512, 2048, 8192, 16384 ![image](https://github.com/user-attachments/assets/225e2ed1-0c99-402d-acce-894f728e083b) **(2) A100** ![image](https://github.com/user-attachments/assets/0658a933-dc5e-4cf3-af48-4b95289b0977) --- csrc/scheduler/reduction.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 87f9d2bffad..4032c9d159e 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -388,9 +388,6 @@ std::unique_ptr innerReductionHeuristic( bool pad_bdimx = bdimx > 16 && bdimx * bdimy < (int64_t)at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; - // If barely just covering reduction dim, don't pad to the next warp - pad_bdimx = pad_bdimx && - bdimx * inner_reduction_unroll_factor != inner_most_dimension_numel; rparams->pad_inner_reduction_to_warp = pad_bdimx; if (rparams->pad_inner_reduction_to_warp) { From 7220207cae9ea9ab79814d4b5d6d8532fa629875 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:36:19 -0700 Subject: [PATCH 02/15] Remove deprecated `clear_cuda_cache` (#3306) This benchmark was added recently and did not have the changes added by PR #3252. The benchmark will fail on the CI due to missing import function --- benchmarks/python/test_adaptive_layernorm_host.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/benchmarks/python/test_adaptive_layernorm_host.py b/benchmarks/python/test_adaptive_layernorm_host.py index 7e3c67b6d8b..7e60da7659d 100644 --- a/benchmarks/python/test_adaptive_layernorm_host.py +++ b/benchmarks/python/test_adaptive_layernorm_host.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest from nvfuser import FusionDefinition, DataType -from nvfuser.pytorch_utils import clear_cuda_cache from .core import run_benchmark import torch @@ -73,8 +72,6 @@ def test_adaptive_layernorm_fwd_benchmark( disable_validation: bool, disable_benchmarking: bool, ): - clear_cuda_cache() - B = 1 T = 30 * 1024 D = 1024 From c14d4181ae499c90b5728c06dfac248a5afab8b0 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Tue, 29 Oct 2024 22:07:09 -0700 Subject: [PATCH 03/15] Profile configurations for InnerOuterPersistent scheduler in python frontend (#3118) # Summary This PR explores auto-tuning a `LayerNormBackward` fusion using the `InnerOuterPersistent` scheduler in the python-frontend. - Create `autotune_persistent.py` to test several parameter configurations then apply `DecisionTreeRegressor` - The selected performance metric is `effective_bandwidth_gbs`. The empirical scheduler selects the configuration that has the highest predicted `effective_bandwidth_gbs`. # Key differences from approach for `Pointwise` scheduler - `vectorize_factor`, `thread_per_block_min`, and `thread_per_block_max` are specified before running `computeHeuristics`. These settings are akin to hyper-parameters used to constrain the generated scheduler parameters. - Create `SchedulerHyperParameters` as an entry in `HeuristicDataCache` to specify these constraints when generating scheduler parameters. # Details 1. Create `struct SchedulerHyperParameters` in `csrc/scheduler/utils.h` 2. Create `HeuristicDataCacheEntry` in `csrc/scheduler/compile_time_info.h` 3. Modify `computeHeuristics` to use hyper-parameter constraints. 4. Expose `SchedulerHyperParameters` in python frontend. 5. Allow user schedulers to define a `HeuristicDataCache` during scheduling. * `ScheduleHyperParameters` contains parameters for `vectorize_factor`, `unroll_factor`, `threads_per_block_min`, and `threads_per_block_max`. --- csrc/python_frontend/fusion_cache.cpp | 3 +- csrc/python_frontend/fusion_cache.h | 4 + csrc/python_frontend/fusion_definition.cpp | 4 + csrc/python_frontend/python_bindings.cpp | 63 ++- csrc/scheduler/compile_time_info.h | 12 +- csrc/scheduler/normalization_inner_outer.cpp | 75 +++- csrc/scheduler/registry.cpp | 2 + csrc/scheduler/utils.h | 28 ++ .../python_scheduling/autotune_persistent.py | 417 ++++++++++++++++++ 9 files changed, 590 insertions(+), 18 deletions(-) create mode 100644 doc/dev/python_scheduling/autotune_persistent.py diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 53dc43bdbe8..83ce851dbab 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -227,7 +227,8 @@ HeuristicParams* UserSchedule::computeHeuristics(SchedulerType scheduler_type) { NVF_CHECK( heuristic_params == nullptr, "Heuristic Scheduler is already defined for this UserSchedule"); - heuristic_params = scheduler->computeHeuristics(fusion(), runtime_info_ref); + heuristic_params = scheduler->computeHeuristics( + fusion(), runtime_info_ref, data_cache.get()); return heuristic_params.get(); } diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index b4283b7bdaf..190671b2b82 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -33,6 +34,9 @@ struct UserSchedule { //! The parameters for scheduler heuristic. std::unique_ptr heuristic_params; + //! The compile-time data cache. + std::unique_ptr data_cache; + //! Concretized, Scheduled Fusion IR std::unique_ptr scheduled_fusion; diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index b512d9d761b..09648a0bf36 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -239,6 +240,9 @@ void FusionDefinition::setupSchedule( user_sched_ = fusionCache()->createUserSchedule( scheds, inputs, device, overwrite_existing_schedule); + // Create scheduler data cache + user_sched_->data_cache = std::make_unique(); + // Building a new Fusion container for scheduling with definition such that // the definition's tensor data members refer to the corresponding IR objects // needed for scheduling. A simple copy of the container would mean the data diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index d7b9e8d1c34..b229107c45b 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -23,9 +23,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -779,6 +781,44 @@ void initNvFuserPythonBindings(PyObject* module) { defineHeuristicParamBindings(nvfuser); + py::class_ hyperparameters( + nvfuser, "SchedulerHyperParameters"); + hyperparameters.def(py::init()); + hyperparameters.def_property( + "vectorize_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.vectorize_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t vectorize_factor_) { + self.vectorize_factor = vectorize_factor_; + }); + hyperparameters.def_property( + "unroll_factor", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.unroll_factor; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t unroll_factor_) { self.unroll_factor = unroll_factor_; }); + hyperparameters.def_property( + "threads_per_block_min", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_min; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_min_) { + self.threads_per_block_min = threads_per_block_min_; + }); + hyperparameters.def_property( + "threads_per_block_max", + [](scheduler_utils::SchedulerHyperParameters& self) { + return self.threads_per_block_max; + }, + [](scheduler_utils::SchedulerHyperParameters& self, + int64_t threads_per_block_max_) { + self.threads_per_block_max = threads_per_block_max_; + }); + //! KernelProfiles are encapsulated in FusionProfiles where each KP //! is associated with a segment. py::class_ kernel_prof(nvfuser, "KernelProfile"); @@ -1401,7 +1441,7 @@ void initNvFuserPythonBindings(PyObject* module) { py::class_ nvf_ops(fusion_def, "Operators"); nvf_ops.def(py::init()); - // ******************** INSERT OP BINDINGS BELOW HERE ******************** +// ******************** INSERT OP BINDINGS BELOW HERE ******************** #define OP_PREFIX "Operators." #define NVFUSER_PYTHON_BINDING_UNARY_OP(op_str, op_name) \ nvf_ops.def( \ @@ -3822,6 +3862,27 @@ void initNvFuserPythonBindings(PyObject* module) { return *parameters->as(); }, py::return_value_policy::reference); + nvf_sched.def( + "schedule_hyperparameters", + [](FusionDefinition::SchedOperators& self) + -> scheduler_utils::SchedulerHyperParameters& { + NVF_CHECK( + self.validUse(), + "Attempting to use a SchedOperators Op prior to definition!"); + UserSchedule* sched = self.fusion_definition->userSchedule(); + auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>( + sched->data_cache.get(), []() { + return std::make_unique< + scheduler_utils::SchedulerHyperParameters>( + /*vectorize_factor=*/1, + /*unroll_factor=*/1, + /*threads_per_block_min=*/1, + /*threads_per_block_max=*/1); + }); + return scheduler_hyperparameters_entry.get(); + }, + py::return_value_policy::reference); } void cleanup() { diff --git a/csrc/scheduler/compile_time_info.h b/csrc/scheduler/compile_time_info.h index 18b5efb0e8e..d413c99ae81 100644 --- a/csrc/scheduler/compile_time_info.h +++ b/csrc/scheduler/compile_time_info.h @@ -46,7 +46,8 @@ enum class CompileTimeEntryType { CAN_SCHEDULE_TRANSPOSE, CAN_SCHEDULE_MUL_SUM_AS_MMA, LOGICAL_REORDER_MAP, - VECTORIZATION_BREAK_POINT_OF_RED_PROD + VECTORIZATION_BREAK_POINT_OF_RED_PROD, + SCHEDULE_HYPERPARAMETERS }; //! Entry type definition class for `DOMAIN_MAP`, @@ -203,6 +204,15 @@ class VectorizationBreakPointOfReductionProducer { CompileTimeEntryType::VECTORIZATION_BREAK_POINT_OF_RED_PROD; }; +//! Entry type definition class for `SCHEDULE_HYPERPARAMETERS`, +//! stores hyperparameters for SchedulerEntry::computeHeuristics +class SchedulerHyperParameters { + public: + using DataType = scheduler_utils::SchedulerHyperParameters; + static const CompileTimeEntryType EntryType = + CompileTimeEntryType::SCHEDULE_HYPERPARAMETERS; +}; + //! Base abstract class for unified storage in `HeuristicDataCache`, //! each entry in `HeuristicDataCache` will be a subclass. class CompileTimeInfoBase : public PolymorphicBase { diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 6dd34f4cab9..2ea854f0a88 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -186,7 +186,9 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( SchedulerRuntimeInfo& runtime_info, HeuristicDataCache* data_cache, const std::vector& reduction_tvs, - const int64_t vectorize_factor) { + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max) { FUSER_PERF_SCOPE( "normalization_inner_outer::getPersistentBufferStorageParams"); @@ -230,9 +232,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( const auto dev_prop = at::cuda::getCurrentDeviceProperties(); int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( - fusion, - reduction_tvs, - InnerOuterPersistentKernelScheduler::threads_per_block_max); + fusion, reduction_tvs, threads_per_block_max); int64_t available_smem = (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; int64_t available_regs = scheduler_utils::register_file_size_56k; @@ -281,8 +281,8 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( tv_buffer_size_regs, dataTypeSize(current_tv->getDataType().value()), vectorize_factor, - InnerOuterPersistentKernelScheduler::threads_per_block_min, - InnerOuterPersistentKernelScheduler::threads_per_block_max, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); buffer_params.smem_buffer_size += tv_buffer_size_smem; @@ -332,6 +332,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( const int64_t outer_dim_numel, const int64_t persistent_buffer_size, const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const int64_t warp_size) { // if inner_dim_numel <= 1024, we are doing multiple reductions per block // with a constant batch size of 1 if vectorized. See Step 5 of @@ -380,11 +382,8 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( }; const int64_t after_vectorization = inner_dim_numel / vectorize_factor; - const int64_t threads_per_block_min = std::min( - after_vectorization, - InnerOuterPersistentKernelScheduler::threads_per_block_min); - const int64_t threads_per_block_max = - InnerOuterPersistentKernelScheduler::threads_per_block_max; + const int64_t threads_per_block_min_after_vectorization = + std::min(after_vectorization, threads_per_block_min); const int64_t batch_min = getMinimumBatch(); const int64_t batch_max = getMaximumInnerOuterPersistentBufferBatch(); @@ -392,7 +391,7 @@ std::pair getBufferBatchSizeAndThreadsPerBlock( // is larger than batch_max, try increase threads per block by a warp until // the threads_per_block reaches threads_per_block_max or the batch size // reaches batch_min. - int64_t threads_per_block = threads_per_block_min; + int64_t threads_per_block = threads_per_block_min_after_vectorization; int64_t inner_batch = ceilDiv(after_vectorization, threads_per_block); while (inner_batch > batch_max && threads_per_block + warp_size <= threads_per_block_max && @@ -432,6 +431,8 @@ std::unique_ptr innerOuterPersistentHeuristic( const int64_t smem_overhead, const size_t tmp_gmem_dtype_size, const size_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max, const bool project_to_input, const PrimDataType index_type) { auto rparams = std::make_unique( @@ -512,6 +513,8 @@ std::unique_ptr innerOuterPersistentHeuristic( outer_dim_numel, regs_buffer_size, iop.inner_vect, + threads_per_block_min, + threads_per_block_max, dev_prop->warpSize); iop.inner_batch = persistent_batch; @@ -743,12 +746,32 @@ std::unique_ptr getInnerOuterPersistentHeuristics( scheduler_utils::persistentBuffers(fusion)); }); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + auto& persistent_buffer_info = persistent_buffer_info_entry.get(); NVF_ERROR( !persistent_buffer_info.persistent_buffers.empty(), "Persistent scheduler requires persistent buffers."); auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); std::unique_ptr rparams = innerOuterPersistentHeuristic( properties.total_iteration_numel, @@ -757,7 +780,9 @@ std::unique_ptr getInnerOuterPersistentHeuristics( buffer_params.smem_buffer_size, buffer_params.smem_overhead, max_outer_reduction_dtype_size, - vectorize_factor, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max, buffer_params.project_to_input, runtime_info.getIndexType()); @@ -1244,9 +1269,29 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( data_cache, (int)(reduced_tv->nDims() - properties.inner_most_dimension_ndims)); + auto scheduler_hyperparameters_entry = + HeuristicDataCacheEntry( + data_cache, [&]() { + return std::make_unique( + /*vectorize_factor=*/vectorize_factor, + /*unroll_factor=*/1, + /*threads_per_block_min=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_min, + /*threads_per_block_max=*/ + InnerOuterPersistentKernelScheduler::threads_per_block_max); + }); + scheduler_utils::SchedulerHyperParameters& hp = + scheduler_hyperparameters_entry.get(); + // check if there is enough register and shared memory for persistence const auto buffer_params = getPersistentBufferStorageParams( - fusion, runtime_info, data_cache, reduction_tvs, vectorize_factor); + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 32cb0aa30f0..3f6b5827db5 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -224,4 +224,6 @@ template class HeuristicDataCacheEntry< template class HeuristicDataCacheEntry; template class HeuristicDataCacheEntry< HeuristicCompileTime::VectorizationBreakPointOfReductionProducer>; +template class HeuristicDataCacheEntry< + HeuristicCompileTime::SchedulerHyperParameters>; } // namespace nvfuser diff --git a/csrc/scheduler/utils.h b/csrc/scheduler/utils.h index 5dab953dead..77317cde31b 100644 --- a/csrc/scheduler/utils.h +++ b/csrc/scheduler/utils.h @@ -173,6 +173,34 @@ inline void parallelizeAllLike( propagate_padding); } +// Common hyperparameters used in heuristic scheduler. These hyperparameters +// are passed to SchedulerEntry::computeHeuristics through the +// HeuristicDataCache. These hyperparameters alter the generation of the +// HeuristicParams for the scheduler. +struct SchedulerHyperParameters { + SchedulerHyperParameters( + int64_t vectorize_factor_, + int64_t unroll_factor_, + int64_t threads_per_block_min_, + int64_t threads_per_block_max_) + : vectorize_factor(vectorize_factor_), + unroll_factor(unroll_factor_), + threads_per_block_min(threads_per_block_min_), + threads_per_block_max(threads_per_block_max_) {} + + //! Number of elements to load per vectorize load. + int64_t vectorize_factor = 1; + + //! Number of iterations to unroll for-loop. + int64_t unroll_factor = 1; + + //! Minimum number of threads per block. + int64_t threads_per_block_min = 1; + + //! Maximum number of threads per block. + int64_t threads_per_block_max = 1; +}; + struct PersistentBufferInfo { std::vector persistent_buffers; std::unordered_set unmappable_dims; diff --git a/doc/dev/python_scheduling/autotune_persistent.py b/doc/dev/python_scheduling/autotune_persistent.py new file mode 100644 index 00000000000..9cb02c6c0e7 --- /dev/null +++ b/doc/dev/python_scheduling/autotune_persistent.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# Owner(s): ["module: nvfuser"] + +import torch +import itertools +import random +from nvfuser import FusionCache, FusionDefinition, SchedulerType, DataType +from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype +from copy import deepcopy + +# ============================ Description ============================ + +# 1. Define a nvfuser fusion and its pytorch eager mode reference. +# +# 2. Profile the CUDA kernel performance by iterating over a set of input +# arguments and scheduler configurations. +# +# 3. Train a regression model to predict the desired performance metric given +# some input arguments and a scheduler configuration. +# +# 4. Measure the performance of the regression model. +# - Calculate RMSE of predicted and actual performance on test set. +# - Find the configuration with the best performance using regression model. +# Then, compare against the heuristic configuration selected by nvfuser. +# - For a specific batch size, gather performance across a range of hidden +# sizes. Calculate performance for best predicted and nvfuser +# configurations. Plot a chart comparing performance using matplotlib. + +# The selected performance metric is effective_bandwidth_gbs. The empirical +# scheduler selects the configuration that has the highest predicted +# effective_bandwidth_gbs. + +# ============================ Configurations ============================ + +# Settings for input tensor generation +num_dimensions = 2 +outer_shapes = [256, 1024, 4096, 16384] +inner_shapes = [2**i for i in range(10, 15)] + +# For pointwise scheduler, we test the cartesian product of vectorization and +# cta_size factors. +parameter_configurations = [ + vectorize_range := [1, 2, 4, 8], + threads_per_cta_range := list(range(128, 288, 32)), +] + +# We profile a range of input shapes with various configurations. +# This argument determines how much of the profiled data to keep as a test set. +test_data_percentage = 0.1 + +# The selected batch size for empirical and nvfuser comparison. +empirical_batch_size = 512 + +# The range of hidden sizes for empirical and nvfuser comparision. +empirical_hidden_sizes = list(range(1024, 28672, 256)) + +# NOTE For 24gb memory limit +# empirical_hidden_sizes = list(range(256, 22784, 256)) + + +def create_inputs(shape): + """Create input arguments for nvfuser fusion and eager mode""" + a = torch.randn(*shape, dtype=torch.bfloat16, device="cuda", requires_grad=True) + grads = torch.randn(*shape, dtype=torch.bfloat16, device="cuda") + weights = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + bias = torch.randn( + shape[1], dtype=torch.bfloat16, device="cuda", requires_grad=True + ) + + eps = 1e-5 + mean = a.to(torch.float).mean(dim=-1) + variance = a.to(torch.float).var(dim=-1, unbiased=False) + invstd = (1.0 / torch.sqrt(variance + eps)).unsqueeze(1) + + nvf_inputs = [a, grads, mean, invstd, weights] + eager_inputs = [a, weights, bias, grads] + return nvf_inputs, eager_inputs + + +# A decorator to create a pointwise fusion given some input arguments. +def create_fusion_func(inputs): + PROMOTE_DTYPES = [DataType.BFloat16, DataType.Half] + dtype = torch_dtype_to_nvfuser_dtype(inputs[0].dtype) + + def fusion_func(fd: FusionDefinition): + T0 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + T1 = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=dtype, is_cpu=False + ) + + T2 = fd.define_tensor( + shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False + ) + T3 = fd.define_tensor( + shape=[-1, 1], contiguity=[True, None], dtype=DataType.Float, is_cpu=False + ) + + T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=dtype, is_cpu=False) + + if dtype in PROMOTE_DTYPES: + T0 = fd.ops.cast(T0, dtype=DataType.Float) + T1 = fd.ops.cast(T1, dtype=DataType.Float) + T4 = fd.ops.cast(T4, dtype=DataType.Float) + + V8 = fd.define_vector([T0.size(0), 1], dtype=DataType.Int) + T9 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + V12 = T0.shape() + T13 = fd.ops.broadcast_in_dim(T9, shape=V12, broadcast_dims=[0, 1]) + T14 = fd.ops.sub(T0, T13) + + T18 = fd.ops.broadcast_in_dim(T3, shape=V12, broadcast_dims=[0, 1]) + T19 = fd.ops.mul(T14, T18) + + T23 = fd.ops.broadcast_in_dim(T4, shape=V12, broadcast_dims=[1]) + T28 = fd.ops.sum(T1, dims=[0], keepdim=False, dtype=DataType.Null) + + T30 = fd.ops.mul(T1, T23) + T31 = fd.ops.mul(T1, T19) + T32 = fd.ops.sum(T31, dims=[0], keepdim=False, dtype=DataType.Null) + + T34 = fd.ops.mul(T30, T18) + T35 = fd.ops.mul(T30, T14) + T36 = fd.ops.sum(T35, dims=[1], keepdim=False, dtype=DataType.Null) + + T40 = fd.ops.broadcast_in_dim(T36, shape=V8, broadcast_dims=[0]) + T41 = fd.ops.neg(T34) + T42 = fd.ops.sum(T41, dims=[1], keepdim=False, dtype=DataType.Null) + T46 = fd.ops.broadcast_in_dim(T42, shape=V8, broadcast_dims=[0]) + S47 = fd.define_scalar(-0.500000, dtype=DataType.Double) + T48 = fd.ops.mul(S47, T40) + S49 = fd.define_scalar(3.00000, dtype=DataType.Double) + T50 = fd.ops.pow(T3, S49) + T51 = fd.ops.mul(T48, T50) + T54 = fd.ops.sum(T46, dims=[1], keepdim=False, dtype=DataType.Null) + T55 = fd.ops.sum(T51, dims=[1], keepdim=False, dtype=DataType.Null) + + T59 = fd.ops.broadcast_in_dim(T55, shape=V8, broadcast_dims=[0]) + T63 = fd.ops.broadcast_in_dim(T59, shape=V12, broadcast_dims=[0, 1]) + T67 = fd.ops.broadcast_in_dim(T2, shape=V8, broadcast_dims=[0]) + T71 = fd.ops.broadcast_in_dim(T67, shape=V12, broadcast_dims=[0, 1]) + + S72 = fd.define_scalar(2.00000, dtype=DataType.Double) + T73 = fd.ops.mul(S72, T63) + T74 = fd.ops.sub(T0, T71) + T75 = fd.ops.mul(T73, T74) + + S77 = fd.ops.reciprocal(T0.size(1)) + T78 = fd.ops.mul(T75, S77) + T82 = fd.ops.broadcast_in_dim(T54, shape=V8, broadcast_dims=[0]) + T86 = fd.ops.broadcast_in_dim(T82, shape=V12, broadcast_dims=[0, 1]) + T88 = fd.ops.mul(S77, T86) + T89 = fd.ops.add(T78, T88) + T90 = fd.ops.add(T34, T89) + + if dtype in PROMOTE_DTYPES: + T28 = fd.ops.cast(T28, dtype=dtype) + T90 = fd.ops.cast(T90, dtype=dtype) + T32 = fd.ops.cast(T32, dtype=dtype) + + fd.add_output(T90) + fd.add_output(T32) + fd.add_output(T28) + + return fusion_func + + +# The pytorch eager mode reference used to validating nvfuser kernel. +def eager_reference(inputs): + inputs_cloned = deepcopy(inputs) + a, weights, bias, grad_output = inputs_cloned + eager_output = torch.nn.functional.layer_norm( + a.to(torch.double), + a.shape[1:], + weight=weights.to(torch.double), + bias=bias.to(torch.double), + ) + grad_output = grad_output.to(torch.double) + eager_output.backward(grad_output) + return [a.grad, weights.grad, bias.grad] + + +# ============================ Function Definitions ============================ + + +# Apply scheduler with custom parameters using decorator +def custom_persistent_scheduler(fd, config): + def inner_fn(): + # Check if compatible with persistent scheduler + status, _ = fd.sched.can_schedule(SchedulerType.inner_outer_persistent) + assert status + + # Modify original parameters + if config is not None: + hyperparameters = fd.sched.schedule_hyperparameters() + vectorize_factor, threads_per_block = config + hyperparameters.vectorize_factor = vectorize_factor + hyperparameters.threads_per_block_min = threads_per_block + hyperparameters.threads_per_block_max = threads_per_block + + # Schedule fusion + fd.sched.schedule(SchedulerType.inner_outer_persistent) + + fd.schedule = inner_fn + return fd + + +# Apply schedule decorator, run fusion, and profile performance +def run_profile(presched_fd, nvf_inputs, eager_inputs, config=None): + scheduled_fd = custom_persistent_scheduler(presched_fd, config) + nvf_outputs = scheduled_fd.execute(nvf_inputs, profile=True) + + # validate correctness + ref_outputs = eager_reference(eager_inputs) + for nvf_out, ref_out in zip(nvf_outputs, ref_outputs): + assert torch.allclose(nvf_out, ref_out, atol=1e-1, rtol=1e-1) + + prof = scheduled_fd.profile() + bandwidth = prof.kernel_profiles[0].effective_bandwidth_gbs + time = prof.kernel_profiles[0].time_ms + return bandwidth, time + + +def argmax(map_config_to_perf): + best_perf = -1 + best_config = None + for config, perf in map_config_to_perf.items(): + if perf > best_perf: + best_perf = perf + best_config = config + return best_config + + +# Given a prediction model, input_shape, and set of parameter configurations, +# find the best parameters +def find_best_parameters(predictor, input_shape, parameter_configurations): + map_config_to_performance = { + config: predictor.predict([[*input_shape, *config]]) + for config in itertools.product(*parameter_configurations) + } + return argmax(map_config_to_performance) + + +# ============================ Run Experiments ================================ + +# Collect data for decision tree +parameters = [] +performance = [] + +for shape in itertools.product(outer_shapes, inner_shapes): + print(shape) + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + # vectorization and threads_per_cta configurations + for config in itertools.product(*parameter_configurations): + perf_metric, _ = run_profile(presched_fd, nvf_inputs, eager_inputs, config) + parameters.append((*shape, *config)) + performance.append(perf_metric) + +# ============================ Separate Data ================================== + +# Separate collected data into training and test sets +train_data = [] +test_data = [] +train_perf = [] +test_perf = [] +test_shapes = set() +all_test_config = {} # key: input_shape, value: (config, perf) + +for data, perf in zip(parameters, performance): + shape = data[:num_dimensions] + config = data[num_dimensions:] + + if shape in all_test_config: + all_test_config[shape][config] = perf + else: + all_test_config[shape] = {config: perf} + + if random.random() < test_data_percentage: + test_data.append(data) + test_perf.append(perf) + else: + test_shapes.add(shape) + train_data.append(data) + train_perf.append(perf) + +# key: input_shape, value: best_config +best_test_config = {shape: argmax(all_test_config[shape]) for shape in test_shapes} + +# ========================= Train Regression Models =========================== + +# Apply decision tree regressor +# Given input shapes and scheduler parameters, predict performance metric. +from sklearn import tree + +clf = tree.DecisionTreeRegressor() +clf = clf.fit(train_data, train_perf) +test_pred = clf.predict(test_data) + +print("===================== measure performance rmse ========================") + +# Estimate prediction error with RMSE +import numpy as np + +test_perf = np.array(test_perf) +print( + "Test prediction error (RMSE)", + np.sqrt(np.mean(np.power(test_perf - test_pred, 2))), +) +print("Test performance", test_perf) +print("Test prediction", test_pred) + +print("======================= compare configurations =======================") +# Find best configuration for test_shapes +print( + "input shape, estimate_config:(vectorization, cta_size), actual_config:(vectorization, cta_size), correct" +) +correctness_count = 0 +mismatch_configs = [] +for shape in test_shapes: + estimate_config = find_best_parameters(clf, shape, parameter_configurations) + + match_config = estimate_config == best_test_config[shape] + if not match_config: + mismatch_configs.append((shape, estimate_config)) + + correctness_count += int(match_config) + print(f"{shape}, {estimate_config}, {best_test_config[shape]}, {match_config}") +print("% of predictions match nvfuser parameters", correctness_count / len(test_shapes)) +print(correctness_count, "out of", len(test_shapes)) + +print("======================= compare performance =========================") + +for shape, estimate_config in mismatch_configs: + nvf_inputs, eager_inputs = create_inputs(shape) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_perf = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + _, nvf_perf = run_profile(presched_fd, nvf_inputs, eager_inputs) + est_perf_faster = est_perf < nvf_perf + print( + f"{shape} \t estimate_perf:{est_perf:.5f} \t nvfuser_perf:{nvf_perf:.5f} \t is_estimated_config_faster:\t{est_perf_faster}" + ) + +print("=====================================================================") + +# For a specific batch size, gather performance across a range of hidden sizes. +# Calculate performance for best predicted and nvfuser configurations. Plot a +# chart comparing performance using matplotlib. + +# NOTE: The matplotlib experiment plots the kernel runtime, which could be +# different than the selected performance metric. Currently, the performance +# metric is effective_bandwidth_gbs. + +import matplotlib.pyplot as plt +import numpy as np + +# Avoid reusing any cached, user-scheduled fusions to have a clean run. +FusionCache.reset() +est_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, est_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs, estimate_config) + est_perfs.append(est_time_ms) + print( + f"decision tree: {empirical_batch_size}, {hidden_shape}, {estimate_config}, {est_time_ms:.3f}" + ) + +FusionCache.reset() +nvf_perfs = [] +for hidden_shape in empirical_hidden_sizes: + nvf_inputs, eager_inputs = create_inputs((empirical_batch_size, hidden_shape)) + estimate_config = find_best_parameters( + clf, (empirical_batch_size, hidden_shape), parameter_configurations + ) + + with FusionDefinition() as presched_fd: + create_fusion_func(nvf_inputs)(presched_fd) + + _, nvf_time_ms = run_profile(presched_fd, nvf_inputs, eager_inputs) + nvf_perfs.append(nvf_time_ms) + print(f"nvfuser: {empirical_batch_size}, {hidden_shape}, {nvf_time_ms:.3f}") + +# Get mean speed-up from nvfuser to empirical configurations across all input shapes. +# Negative value mean empirical configurations are slower than nvfuser. +print("Mean speed-up", np.mean(np.array(nvf_perfs) - np.array(est_perfs))) + +np_hidden_size = np.array(empirical_hidden_sizes) +plt.plot(np_hidden_size, np.array(est_perfs)) +plt.plot(np_hidden_size, np.array(nvf_perfs)) + +plt.xlabel("Hidden Size") +plt.ylabel("Time(ms)") +plt.title( + f"Batch Size = {empirical_batch_size}, Compare Decision Tree Heuristic vs NvFuser" +) +plt.legend(["decision_tree", "nvfuser"], loc="lower right") +plt.savefig(f"persistent_inner_outer_empirical_batchsize{empirical_batch_size}.png") + +# ============================================================================= From d88dcba805c0ea5a72e94c78ae03df32f6b3a2c2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Oct 2024 08:03:05 -0400 Subject: [PATCH 04/15] Disable nvfusertest_serde_check if DEBUG_SERDE=disable (#3304) This is another attempt to fix the codediff CI job Fixes #3265. Fixes #3283. --- tests/python/utils.py | 20 +++++++++++++++++--- tools/codediff/compare_codegen.sh | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/python/utils.py b/tests/python/utils.py index 4cb0c0e4cb3..2a7fadc4a14 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -274,6 +274,7 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): # This DEBUG_SERDE environment flag is used to debug serialization failures. # +# If DEBUG_SERDE=debug # 1) It disables automatically saving FusionCache upon program exit. Therefore, # it has to be a global flag not per-test. # @@ -283,8 +284,14 @@ def check_cpp_translation(reference_outputs, fd, inputs, device=None): # # 3) It keeps the temporary files that are created during serde_check. # Normally, these files are deleted after each test. -env_var_debug_serde = os.getenv("DEBUG_SERDE") -debug_serde: bool = env_var_debug_serde in ("true", "1") +# +# DEBUG_SERDE=disable +# 1) It disables the @nvfusertest_serde_check decorator. This disables checking +# that serde round-trips preserve the definition during testing. +env_var_debug_serde = os.getenv("DEBUG_SERDE", "").lower() +debug_serde: bool = env_var_debug_serde == "debug" +disable_serde: bool = env_var_debug_serde == "disable" +del env_var_debug_serde # The pytest framework and test_python_frontend.py use different arguments for @@ -314,7 +321,7 @@ def basic_serde_check(): ) else: raise RuntimeError( - "***** Use DEBUG_SERDE=true to debug serialization failure." + "***** Use DEBUG_SERDE=debug to debug serialization failure." ) @@ -323,6 +330,11 @@ def basic_serde_check(): # binary. Call FusionCache.reset() to clear the cache after running an error # test in `test_python_frontend.py'. def atexit_serde_check(): + if disable_serde: + # Ignore FusionCache and automatic serialization if serde check is + # disabled + return + from nvfuser import FusionCache if not debug_serde: @@ -343,6 +355,8 @@ def nvfusertest_serde_check(test_fn: Callable): function. Currently, it uses serialization to rebuild the FusionCache structure. """ + if disable_serde: + return test_fn def inner_fn(*args, **kwargs): self, fusion_func, inputs = args diff --git a/tools/codediff/compare_codegen.sh b/tools/codediff/compare_codegen.sh index 8ae33f8805c..478936a047d 100755 --- a/tools/codediff/compare_codegen.sh +++ b/tools/codediff/compare_codegen.sh @@ -189,7 +189,7 @@ collect_kernels() { # Make tests reproducible export NVFUSER_TEST_RANDOM_SEED=0 export NVFUSER_DISABLE=parallel_compile - export DEBUG_SERDE=true + export DEBUG_SERDE=disable # run tests and benchmarks with cuda_to_file and dump output to files mkdir -p "$outdir/$commit" From f394b4e382c03e9fcec55d31c584d12f13c4e1de Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:46:40 -0400 Subject: [PATCH 05/15] reorder outer reduction tv in inner-outer scheduler when there are view ops in the fusion (#3287) **Issue**: The IDs involved in view transforms are moved to the outer most positions in the loop domain, e.g. `T0[i0, i2, i3] --> T3[i0, i2*i3]`, `propagateReshapeTransforms` moves `{i2*i3}` to the outer most position and the loop domain becomes ` (iS31{( i2 * i3 )}, iS0{i0})`. To maintain the original reduction axis, for reduction tv we should reorder the loop domain back to its original logical domain, this is done for inner reduction in innerOuter scheduler but not for outer reduciton. (1) Inner reduction tv after `propagateReshapeTransforms` ``` T3_l_float[ rS13{( i2 * i3 )}, iS12{i0} ] logical domain : (iS12{i0}, rS13{( i2 * i3 )}) contiguity: t n loop domain : (rS13{( i2 * i3 )}, iS12{i0}) ``` (2) Inner reduction tv after `reorder(domainReorderAsLogicalMap)` ``` T3_l_float[ iS12{i0}, rS13{( i2 * i3 )} ] logical domain : (iS12{i0}, rS13{( i2 * i3 )}) contiguity: t n loop domain : (iS12{i0}, rS13{( i2 * i3 )}) ``` (3) Outer reduction tv after `propagateReshapeTransforms` ``` T6_l_float[ iS19{( i2 * i3 )}, rS18{i0} ] logical domain : (rS18{i0}, iS19{( i2 * i3 )}) contiguity: n t loop domain : (iS19{( i2 * i3 )}, rS18{i0}) ``` `reorder(domainReorderAsLogicalMap)` is not used for Outer reduction tv. This leads to error `Cannot rfactor axes that are not reduction axes.` when the scheduler tries to rFactor outer dim, which is `iS19{( i2 * i3 )` **Fix**: Add `reorder(domainReorderAsLogicalMap)` for outer reduction tv. **Results**: Added a unit test, err is fixed. Outer reduction tv is correctly reordered as: ``` T6_l_float[ rS18{i0}, iS19{( i2 * i3 )} ] logical domain : (rS18{i0}, iS19{( i2 * i3 )}) contiguity: n t loop domain : (rS18{i0}, iS19{( i2 * i3 )}) ``` --- csrc/scheduler/normalization_inner_outer.cpp | 9 ++++ .../test_combined_inner_outer_reduction.cpp | 46 +++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 2ea854f0a88..e4bbc803033 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -815,6 +815,15 @@ void scheduleReductionCombinedOuter( } }; for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + // merge tensorview to [reduction, iteraiton] domains mergeReductionOrIterDomains(outer_reduction_tv, true); mergeReductionOrIterDomains(outer_reduction_tv, false); diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 2071aeb0e86..95eaadd4ad7 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -994,4 +994,50 @@ TEST_F(CombinedSchedulerTest, SharedMemoryPersistentVectFactor) { aten_inputs, heuristic_params->as()->lparams); testValidate(&fusion_copy, cg_outputs, aten_inputs, __LINE__, __FILE__); } + +using InnerOuterReshapeTest = NVFuserFixtureParamTest; +INSTANTIATE_TEST_SUITE_P( + , + InnerOuterReshapeTest, + testing::Bool(), + testing::PrintToStringParamName()); +TEST_P(InnerOuterReshapeTest, ReshapeOuterDimTrueOrFalse) { + auto reshape_outer_dim = GetParam(); + Fusion fusion; + FusionGuard fg(&fusion); + // reshape a 3D input tensor to 2D + // [4, 1024, 4096] -> [4096, 4096] + // [4096, 4, 1024] -> [4096, 4096] + const int dim0 = reshape_outer_dim ? 4 : 4096; + const int dim1 = reshape_outer_dim ? 1024 : 4; + const int dim2 = reshape_outer_dim ? 4096 : 1024; + auto dtype = DataType::Half; + auto tv0 = makeContigTensor(3, dtype); + fusion.addInput(tv0); + auto tv1 = castOp(DataType::Float, tv0); + + auto tv4 = reshape(tv1, {dim0, dim1, dim2}, {4096, 4096}); + + auto tv5 = sum(tv4, {1}); + auto tv6 = broadcast(tv5, {false, true}); + auto tv7 = add(tv6, tv4); + auto tv8 = sum(tv4, {0}); + auto tv9 = castOp(DataType::Half, tv7); + auto tv10 = castOp(DataType::Half, tv8); + fusion.addOutput(tv9); + fusion.addOutput(tv10); + + Fusion fusion_copy = fusion; + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({dim0, dim1, dim2}, options); + std::vector aten_inputs = {t0}; + auto cg_results = + scheduleAndRun(&fusion, SchedulerType::InnerOuterPersistent, aten_inputs); + auto persistent_params = cg_results.heuristic_params->as(); + ASSERT_FALSE(persistent_params->project_persistent_buffers); + testValidate( + &fusion_copy, cg_results.outputs, aten_inputs, __LINE__, __FILE__); +} + } // namespace nvfuser From a4465df112ea6ecdb9dc47cb1cc4e8c2ffa3e162 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 30 Oct 2024 12:20:26 -0700 Subject: [PATCH 06/15] Host benchmarking for a fusion with multiple segments (#3307) This benchmark uses matmul + pointwise op to create a fusion with 12 segments instead of using `segment_set` to force segmentation. ![Screenshot 2024-10-29 at 4 41 47 PM](https://github.com/user-attachments/assets/2e65f8b9-489b-431b-8694-ab265f90ce32) For `host_benchmark_mode='compile'`, the profile is shown below. The `Finding valid segment solutions` pass takes 52 ms ![Screenshot 2024-10-29 at 4 52 38 PM](https://github.com/user-attachments/assets/eb02a309-2b17-4c7b-8a89-325a462323cb) --- benchmarks/python/test_many_segments_host.py | 83 ++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 benchmarks/python/test_many_segments_host.py diff --git a/benchmarks/python/test_many_segments_host.py b/benchmarks/python/test_many_segments_host.py new file mode 100644 index 00000000000..9515da141a0 --- /dev/null +++ b/benchmarks/python/test_many_segments_host.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +import pytest +from nvfuser import FusionDefinition, DataType +from .core import run_benchmark +import torch + + +def many_matmul_fusion(fd: FusionDefinition) -> None: + x = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False + ) + y = fd.define_tensor( + shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False + ) + a = fd.ops.add(x, y) + for _ in range(5): + a_transpose = fd.ops.permute(a, [1, 0]) + matmul_out = fd.ops.matmul(a_transpose, y) + add_out = fd.ops.add(a_transpose, y) + a = fd.ops.add(matmul_out, add_out) + fd.add_output(a) + + +@pytest.mark.parametrize("host_bench_mode", ["compile", "steady", "dynamic"]) +def test_many_segment_benchmark( + benchmark, + host_bench_mode: str, + disable_validation: bool, + disable_benchmarking: bool, +): + inputs = [torch.randn(16, 16, device="cuda", dtype=torch.float) for _ in range(2)] + + # Generate multiple inputs to measure dynamic shape overhead. + if host_bench_mode == "dynamic": + input_sizes = [4, 8, 16, 32, 64, 128] + # Generate matrices of size x size dimensions + inputs = [ + [ + torch.randn(size, size, device="cuda", dtype=torch.float) + for _ in range(2) + ] + for size in input_sizes + ] + + with FusionDefinition() as fd: + many_matmul_fusion(fd) + + def validate(input): + x, y = input + eager_output = x + y + for _ in range(5): + eager_transpose = eager_output.t() + matmul_out = torch.matmul(eager_transpose, y) + add_out = eager_transpose + y + eager_output = matmul_out + add_out + fd.validate(input, [eager_output]) + + # Validate number of segments + _ = fd.execute(input, profile=True) + num_segments = fd.profile().segments + expected_segments = 12 + assert ( + num_segments == expected_segments + ), f"Expected {expected_segments} fusion segments, got {num_segments}." + + if not disable_validation: + if host_bench_mode == "dynamic": + # Run validate for all input sizes. + for input in inputs: + validate(input) + else: + validate(inputs) + + if not disable_benchmarking: + run_benchmark( + benchmark, + None, + inputs, + device=f"host:{host_bench_mode}", + fusion_fn=many_matmul_fusion, + ) From 81dd1d288cf4fda112cc3488b30f33e90fb23ae6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:49:25 -0400 Subject: [PATCH 07/15] Use deep evaluation of extents in remove_empty pass (#3301) For dynamic fusions, we detect empty tensors and set their extents to immediate constant 0. Later, in the remove_empty preseg pass, we do a shallow check that extents are empty so that we can simplify the fusion. When the fusion is not dynamic there is no concretization step where we would do this extent replacement, so we might have constant 0 extents that are compound scalars. This caused us to miss some empty tensors in #3292, particularly one of the inputs to a `cat`. This PR: - Uses a deep evaluation of each `getMaybeExpandedExtent()` to determine if an axis is empty - Adds an ExpressionEvaluator field to `EmptyTensorRemover` to avoid repeating the deep evaluation when possible. This won't help prevent repeated evaluation of symbolic extents; we could track those in an `unordered_set` potentially instead. Fixes #3292 --------- Co-authored-by: Naoya Maruyama --- csrc/preseg_passes/remove_empty.cpp | 63 +++++++++++++---------- csrc/serde/fusion_record.cpp | 3 +- tests/python/test_python_frontend.py | 74 ++++++++++++++++++++++++++++ 3 files changed, 113 insertions(+), 27 deletions(-) diff --git a/csrc/preseg_passes/remove_empty.cpp b/csrc/preseg_passes/remove_empty.cpp index 7893d993ded..0be346ee71e 100644 --- a/csrc/preseg_passes/remove_empty.cpp +++ b/csrc/preseg_passes/remove_empty.cpp @@ -7,10 +7,12 @@ // clang-format on #include +#include #include #include #include #include +#include #include #include @@ -21,29 +23,6 @@ namespace nvfuser::preseg_passes { namespace { -//! Get a vector of the integer positions of constant zero extent axes in the -//! input domain. This will typically be used like -//! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))` -std::vector emptyAxes(const std::vector& domain) { - std::vector empty_axes; - for (auto ax : c10::irange(domain.size())) { - auto id = domain.at(ax); - if (id->getMaybeExpandedExtent()->isConst() && - id->getMaybeExpandedExtent()->evaluate().as() == 0) { - empty_axes.push_back((int64_t)ax); - } - } - return empty_axes; -} - -//! Check whether a TensorView is empty. During concretization, we traverse to -//! find a minimal set of TensorViews that have zero extents, and we then set -//! their extents to a constant 0. Here we check for those constant zero -//! extents. -bool isTVEmpty(TensorView* tv) { - return !emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain())).empty(); -} - //! EmptyTensorRemover performs a backward traversal of the Fusion. When it //! detects a TensorView that has at least one extent that is zero, we do the //! following: @@ -69,9 +48,34 @@ class EmptyTensorRemover : public DeadCodeRemover { public: EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {} - protected: + private: using DeadCodeRemover::handle; + //! Get a vector of the integer positions of constant zero extent axes in the + //! input domain. This will typically be used like + //! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))` + std::vector emptyAxes(const std::vector& domain) { + std::vector empty_axes; + for (auto ax : c10::irange(domain.size())) { + auto id = domain.at(ax); + PolymorphicValue extent = + expr_eval_.evaluate(id->getMaybeExpandedExtent()); + if (extent.hasValue() && extent.as() == 0) { + empty_axes.push_back((int64_t)ax); + } + } + return empty_axes; + } + + //! Check whether a TensorView is empty. During concretization, we traverse to + //! find a minimal set of TensorViews that have zero extents, and we then set + //! their extents to a constant 0. Here we check for those constant zero + //! extents. + bool isTVEmpty(TensorView* tv) { + return !emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain())) + .empty(); + } + //! If tv is a fusion output, we check whether it is empty and if so, replace //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. @@ -257,8 +261,9 @@ class EmptyTensorRemover : public DeadCodeRemover { "Inputs to CatOp must be outputs of PadOps"); auto tv = inp->definition()->as()->in()->as(); auto cat_id = TensorDomain::noReductions(tv->getLogicalDomain()).at(dim); - if (cat_id->getMaybeExpandedExtent()->isConst() && - cat_id->getMaybeExpandedExtent()->evaluate().as() == 0) { + PolymorphicValue extent = + expr_eval_.evaluate(cat_id->getMaybeExpandedExtent()); + if (extent.hasValue() && extent.as() == 0) { continue; } non_empty_inputs.push_back(tv); @@ -312,6 +317,12 @@ class EmptyTensorRemover : public DeadCodeRemover { registerReplacement(out, new_tv); } } + + private: + // We use this ExpressionEvaluator without binding any inputs. This lets us + // quickly repeatedly evaluate compound constant expressions like + // ( fmax(0, ( fmin(( ceilDiv(576, 9) ), 0) )) ) + ExpressionEvaluator expr_eval_; }; } // namespace diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index f40edaaea44..5de2cda9873 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -43,7 +43,8 @@ python_frontend::RecordFunctor* deserializeOpRecord( const RecordFunctor* buffer) { NVF_ERROR( str_to_func_map.find(buffer->name()->str()) != str_to_func_map.end(), - "Missing mapping from operation string to nvfuser function in serde deserialization."); + "Missing mapping from operation string to nvfuser function in serde deserialization: ", + buffer->name()->str()); return new python_frontend::OpRecord( parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs()), diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 8080c48278c..874223471eb 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4600,3 +4600,77 @@ def fusion_func(fd: FusionDefinition) -> None: nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) for out in nvf_out: self.assertTrue(out.allclose(x[:, 1:, 2:])) + + def test_issue_3292(self): + inputs = [ + torch.testing.make_tensor( + (5, 5, 576), dtype=torch.float32, device="cuda:0" + ), + ] + + def fusion_func(fd: FusionDefinition) -> None: + T2 = fd.define_tensor( + shape=[5, 5, 576], + contiguity=[True, True, True], + dtype=DataType.Float, + is_cpu=False, + stride_order=[2, 1, 0], + ) + T30 = fd.ops.reshape(T2, new_shape=[5, 5, 1, 9, 64]) + T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4]) + T50 = fd.ops.slice( + T31, + start_indices=[0, 0, 0, 0, 0], + end_indices=[5, 1, 7, 5, 64], + strides=[1, 1, 1, 1, 1], + manual_normalization=0, + ) + T108 = fd.ops.reshape(T50, new_shape=[5, 7, 5, 64]) + T136 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 32], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T152 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 32], + end_indices=[5, 7, 5, 64], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T153 = fd.ops.neg(T152) + T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0) + T161 = fd.ops.mul(T108, T108) + T168 = fd.ops.mul(T154, T154) + T169 = fd.ops.add(T161, T168) + T185 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 32], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T201 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 32], + end_indices=[5, 7, 5, 64], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T202 = fd.ops.neg(T201) + T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0) + T205 = fd.ops.mul(T203, T203) + T222 = fd.ops.slice( + T108, + start_indices=[0, 0, 0, 0], + end_indices=[5, 7, 5, 0], + strides=[1, 1, 1, 1], + manual_normalization=0, + ) + T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0) + fd.add_output(T223) + + # is_clonable=False is because translation fails with missing ceilDiv + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False) From bad9e50bc9539054050310f423317b4c1d259c53 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 30 Oct 2024 13:57:59 -0700 Subject: [PATCH 08/15] Factorize ExpressionEvaluator::bind_. (#3305) No functionality changes. --- csrc/expr_evaluator.cpp | 143 +++++++++++++++++++++------------------- csrc/expr_evaluator.h | 18 +++-- 2 files changed, 86 insertions(+), 75 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 0fd62022098..d4ca6daa022 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -129,6 +129,79 @@ void validateValWithConcreteValue( } // namespace +void ExpressionEvaluator::bindTensorDomain( + const TensorView* tv, + const at::Tensor& t, + const bool evaluate_validate) { + auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); + NVF_ERROR( + t.dim() == (int64_t)logical_domain.size(), + "Expected ", + getInputPosString(tv), + tv->toString(), + ", to be bound to a tensor of rank ", + logical_domain.size(), + ", but got a tensor of rank ", + t.dim()); + for (auto i : c10::irange(t.dim())) { + auto id = logical_domain[i]; + if (id->isBroadcast()) { + // DIDs are ignored for broadcast. + bind_(logical_domain[i]->extent(), 1, evaluate_validate); + if (id->hasExpandedExtent()) { + // Verify that t is also expanded + NVF_ERROR( + t.size(i) == 1 || t.stride(i) == 0, + "IterDomain ", + id->toString(), + " in ", + getInputPosString(tv), + "TensorView ", + tv->toString(), + " has expanded extent but input tensor has size ", + t.size(i), + " and stride ", + t.stride(i), + " in dimension ", + i); + bind_( + logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate); + } + } else { + if (logical_domain[i]->isDeviceDim()) { + // Currently we have the restrictions: + // (1) Devices parallelized axis extent == DeviceMesh's extent + // (2) Device parallelized axis cannot be split or merged + // Therefore, the device parallelized extents will always be allocated + // with size 1, but the symbolic axis extent is binded with the extent + // of the DeviceMesh + NVF_CHECK( + 1 == t.size(i), + "TensorView ", + tv->toString(), + getInputPosString(tv), + " IterDomain ", + id->toString(), + "is sharded and must have size 1, but input tensor has size ", + t.size(i)); + NVF_CHECK( + tv->hasDeviceMesh(), + "TV ", + tv->toString(), + getInputPosString(tv), + " has an empty DeviceMesh with DID parallelization") + bind_( + logical_domain[i]->extent(), + static_cast( + tv->getDeviceMesh().size(logical_domain[i]->getParallelType())), + evaluate_validate); + } else { + bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); + } + } + } +} + void ExpressionEvaluator::bind_( const Val* value, PolymorphicValue concrete_value, @@ -162,75 +235,7 @@ void ExpressionEvaluator::bind_( } if (auto tv = dynamic_cast(value)) { const auto& t = concrete_value.as(); - auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain()); - NVF_ERROR( - t.dim() == (int64_t)logical_domain.size(), - "Expected ", - getInputPosString(tv), - tv->toString(), - ", to be bound to a tensor of rank ", - logical_domain.size(), - ", but got a tensor of rank ", - t.dim()); - for (auto i : c10::irange(t.dim())) { - auto id = logical_domain[i]; - if (id->isBroadcast()) { - // DIDs are ignored for broadcast. - bind_(logical_domain[i]->extent(), 1, evaluate_validate); - if (id->hasExpandedExtent()) { - // Verify that t is also expanded - NVF_ERROR( - t.size(i) == 1 || t.stride(i) == 0, - "IterDomain ", - id->toString(), - " in ", - getInputPosString(tv), - "TensorView ", - tv->toString(), - " has expanded extent but input tensor has size ", - t.size(i), - " and stride ", - t.stride(i), - " in dimension ", - i); - bind_( - logical_domain[i]->expandedExtent(), - t.size(i), - evaluate_validate); - } - } else { - if (logical_domain[i]->isDeviceDim()) { - // Currently we have the restrictions: - // (1) Devices parallelized axis extent == DeviceMesh's extent - // (2) Device parallelized axis cannot be split or merged - // Therefore, the device parallelized extents will always be allocated - // with size 1, but the symbolic axis extent is binded with the extent - // of the DeviceMesh - NVF_CHECK( - 1 == t.size(i), - "TensorView ", - tv->toString(), - getInputPosString(tv), - " IterDomain ", - id->toString(), - "is sharded and must have size 1, but input tensor has size ", - t.size(i)); - NVF_CHECK( - tv->hasDeviceMesh(), - "TV ", - tv->toString(), - getInputPosString(tv), - " has an empty DeviceMesh with DID parallelization") - bind_( - logical_domain[i]->extent(), - static_cast(tv->getDeviceMesh().size( - logical_domain[i]->getParallelType())), - evaluate_validate); - } else { - bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); - } - } - } + bindTensorDomain(tv, t, evaluate_validate); } if (value->isA()) { known_named_scalars_[value->as()->name()] = diff --git a/csrc/expr_evaluator.h b/csrc/expr_evaluator.h index ef1114bb8ff..b6c8e1857ea 100644 --- a/csrc/expr_evaluator.h +++ b/csrc/expr_evaluator.h @@ -25,12 +25,6 @@ class PrecomputedValues; //! Calculate Fusion IR expressions class ExpressionEvaluator { - NVF_API void bind_( - const Val* value, - PolymorphicValue concrete_value, - bool evaluate_validate); - void bind_(const std::string& name, PolymorphicValue concrete_value); - public: //! Bind a concrete value to an IR variable //! If evaluate_validate is true, and value is evaluatable with the @@ -98,6 +92,18 @@ class ExpressionEvaluator { ExpressionEvaluator clone(IrCloner& ir_cloner) const; private: + void bind_( + const Val* value, + PolymorphicValue concrete_value, + bool evaluate_validate); + + void bind_(const std::string& name, PolymorphicValue concrete_value); + + void bindTensorDomain( + const TensorView* tv, + const at::Tensor& t, + bool evaluate_validate); + const PolymorphicValue& getValue( const Val* value, const std::unordered_map& From ca70ad7f70251294adc38e0b71df6b83c5fe8c05 Mon Sep 17 00:00:00 2001 From: "Wang, Xiao" <24860335+xwang233@users.noreply.github.com> Date: Wed, 30 Oct 2024 15:37:45 -0700 Subject: [PATCH 09/15] Update !build and !test triggers for CI pipelines; add a CI hello message for pull requests (#3315) per title; wiki and backend trigger will be updated later --- .github/workflows/nvfuser-ci-trigger.yml | 2 +- .github/workflows/pull.yml | 25 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/pull.yml diff --git a/.github/workflows/nvfuser-ci-trigger.yml b/.github/workflows/nvfuser-ci-trigger.yml index c74472a881e..d4a79b95a9c 100644 --- a/.github/workflows/nvfuser-ci-trigger.yml +++ b/.github/workflows/nvfuser-ci-trigger.yml @@ -16,7 +16,7 @@ jobs: # This job only runs for pull request comments if: | - startsWith(github.event.comment.body, '!build') && + ( startsWith(github.event.comment.body, '!build') || startsWith(github.event.comment.body, '!test') ) && (github.actor == 'xwang233' || github.actor == 'jjsjann123' || github.actor == 'chang-l' || github.actor == 'csarofeen' || github.actor == 'drzejan2' || github.actor == 'IvanYashchuk' || github.actor == 'jacobhinkle' || github.actor == 'kevinstephano' || github.actor == 'liqiangxl' || github.actor == 'mmigdal-nv' || github.actor == 'naoyam' || github.actor == 'ptrblck' || github.actor == 'rdspring1' || github.actor == 'samnordmann' || github.actor == 'zasdfgbnm' || github.actor == 'crcrpar' || github.actor == 'nWEIdia' || github.actor == 'Priya2698' || github.actor == 'wujingyue' || github.actor == 'tfogal' || github.actor == 'protonu' || github.actor == 'cowanmeg' || github.actor == 'nsarka') steps: - name: Check if comment is issued by authorized person diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml new file mode 100644 index 00000000000..d27de8781ed --- /dev/null +++ b/.github/workflows/pull.yml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# A workflow to send CI-related helpful information to PRs +name: pull +on: + pull_request: + +run-name: CI status hello ${{ github.event.pull_request.number }} - ${{ github.event.pull_request.head.sha }} +jobs: + status_hello: + name: send CI hello status + runs-on: ubuntu-latest + permissions: + statuses: write + steps: + - name: Set CI hello status + run: | + curl \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.event.pull_request.head.sha }} \ + -d "{\"state\":\"success\",\"target_url\":\"https://github.com/NVIDIA/Fuser/wiki/Bot-Commands\",\"description\":\"Authorized users: comment !build or !test to trigger CI pipelines. See wiki.\",\"context\":\"CI notes\"}" From 3d9677de79b092101358863119bdef588ed3bccb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 30 Oct 2024 20:07:16 -0400 Subject: [PATCH 10/15] Add script to check for non-determinism (#3312) For example, try running the following commands: ```bash rm -rf /tmp/nvfuser_kernel_db tools/check_determinism.sh -- pytest -vs tests/python/test_ops.py::test_correctness_var_mean_float64 # This fails with a message like # 10845c10845 # < __global__ void nvfuser_inner_persistent_f7_c1_r0_g2(Tensor T0, Tensor T8, Tensor T7) { # --- # > __global__ void nvfuser_inner_persistent_f7_c1_r0_g2(Tensor T0, Tensor T7, Tensor T8) { # 10897c10897 # < T8[0] # --- # > T7[0] # 10923c10923 # < T7[0] # --- # > T8[0] # Diff of __tmp_kernel_inner_persistent_f7_c1_r0_g2.cu from rep 1 to rep 5 (above) is non-zero rm -rf /tmp/nvfuser_kernel_db export DEBUG_SERDE=disable tools/check_determinism.sh -- pytest -vs tests/python/test_ops.py::test_correctness_var_mean_float64 # This succeeds ``` Note that this script will delete any existing *.cu files in the current directory. --- tools/check_determinism.sh | 105 +++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100755 tools/check_determinism.sh diff --git a/tools/check_determinism.sh b/tools/check_determinism.sh new file mode 100755 index 00000000000..9b1d355afd5 --- /dev/null +++ b/tools/check_determinism.sh @@ -0,0 +1,105 @@ +#!/bin/bash + +set -e +set -o pipefail + +usage() { + echo "Usage: $0 [-h] [-n NUMREPS=10] -- command to run]" +} + + +while getopts "n:h" arg +do + case $arg in + n) + NUMREPS=$OPTARG + shift + ;; + h | ?) + usage + exit 1 + ;; + esac +done +# getopts stops parsing if it sees "--". We can detect that case and record command +while [[ $# -gt 0 ]] +do + if [[ "$1" == "--" ]] + then + shift + break + fi + shift +done +CMD=$* + +export NVFUSER_DUMP=cuda_to_file + +KERNELDIR=$(mktemp -d) + +cleanup() { + rm -rf "$KERNELDIR" +} + +trap "cleanup" EXIT + +FIRSTREPDIR="$KERNELDIR/1" + +retval=0 +for rep in $(seq 1 "$NUMREPS") +do + NUMEXISTINGCUFILES=$(find . -maxdepth 1 -name \*.cu | wc -l) + if [[ $NUMEXISTINGCUFILES -ne 0 ]] + then + KERNELBACKUPDIR=./check_determinism-kernelbackup$(date +%Y%m%d.%H%M%S) + echo "Backing up $NUMEXISTINGCUFILES existing .cu files to $KERNELBACKUPDIR" + mkdir -p "$KERNELBACKUPDIR" + mv ./*.cu "$KERNELBACKUPDIR" + fi + # $CMD does not need to succeed for us to analyze it + set +e + $CMD + set -e + + REPDIR="$KERNELDIR/$rep" + mkdir -p "$REPDIR" + mv ./*.cu "$REPDIR/" + + NUMFIRST=$(find "$FIRSTREPDIR" -name \*.cu | wc -l) + NUMREP=$(find "$REPDIR" -name \*.cu | wc -l) + if [[ $NUMREP -ne $NUMFIRST ]] + then + echo "Created $NUMFIRST kernels on first repetition and $NUMREP on repetition $rep" + retval=1 + fi + for newkernel in "$REPDIR"/*.cu + do + basename=$(basename "$newkernel") + firstkernel="$FIRSTREPDIR/$basename" + if [ ! -f "$firstkernel" ] + then + echo "Kernel file $newkernel in repetition $rep does not exist in first repetition" + retval=1 + continue + fi + set +e + diff "$firstkernel" "$newkernel" + diffstatus=$? + set -e + if [[ $diffstatus -ne 0 ]] + then + printf 'Diff of %s from rep 1 to rep %d (above) is non-zero\n' "$basename" "$rep" + retval=1 + continue + fi + done + if [[ $retval -ne 0 ]] + then + # Stop repetitions after first failure + break + else + echo "Generated kernels all match" + fi +done + +exit $retval From 621e1466844aea9c636a08508375c40a06459307 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 31 Oct 2024 08:50:06 -0700 Subject: [PATCH 11/15] Add information for coordinating segments in python frontend. (#3289) # Overview This PR adds information necessary for coordinating segments in the python frontend. Changes pulled from https://github.com/NVIDIA/Fuser/pull/3025. ## PR Details * Track the fusion state ids for the inputs, outputs, and extents of a Fusion. Inputs and extents are used to gather tensor arguments and scalars to run a fusion segment, while the outputs are employed to store results between segments. * A map from a CPP value to its corresponding fusion state id, which is needed to map values from original fusion to its segmented fusions. ## Implementation Details - `FusionState` is a lightweight representation of a CPP `Fusion`. - When calling `buildFusionIr`, a CPP `Fusion` is created from the Python `FusionDefinition`. At this point, the `FusionState` creates a mapping from CPP `Fusion` to its `State` objects. - However, the `FusionState` is temporary and the CPP `Fusion` is cached in `FusionCache`. The information linking the CPP `Fusion` and Python `FusionDefinition` is stored in `FusionCache`. - When we create a new `FusionState`, we look for a cached CPP `Fusion`. If it exists, we restore the mapping from the data stored in `FusionSchedules`. --- csrc/python_frontend/fusion_cache.cpp | 12 ++- csrc/python_frontend/fusion_cache.h | 8 ++ csrc/python_frontend/fusion_definition.cpp | 22 ++++++ csrc/python_frontend/fusion_record.h | 8 +- csrc/python_frontend/fusion_state.cpp | 89 +++++++++++++++++++++- csrc/python_frontend/fusion_state.h | 27 ++++++- csrc/python_frontend/python_bindings.cpp | 3 + nvfuser/__init__.py | 1 - tests/python/test_python_frontend.py | 36 +++++++++ 9 files changed, 194 insertions(+), 12 deletions(-) diff --git a/csrc/python_frontend/fusion_cache.cpp b/csrc/python_frontend/fusion_cache.cpp index 83ce851dbab..e95ee6820da 100644 --- a/csrc/python_frontend/fusion_cache.cpp +++ b/csrc/python_frontend/fusion_cache.cpp @@ -781,8 +781,8 @@ void FusionCache::deserialize(std::string filename) { NVF_CHECK( trie_ptr->fusion_id == fb_trie_node->fusion_id(), "The fusion id for this TrieNode should already be set.") - Fusion* fusion = - queryFusionSchedules(fb_trie_node->fusion_id())->preschedFusion(); + FusionSchedules* fs = queryFusionSchedules(fb_trie_node->fusion_id()); + Fusion* fusion = fs->preschedFusion(); try { // There could be bad fusion in the serialization. state->buildFusionIr(fusion); @@ -790,6 +790,14 @@ void FusionCache::deserialize(std::string filename) { // catch exception and setException for the terminal node trie_ptr->setException(e.what()); } + // The FusionState creates a mapping from CPP Fusion to its State objects. + // Since the CPP Fusion is cached in FusionCache and the FusionState is + // temporary, the information linking CPP Fusion and Python + // FusionDefinition is stored in FusionCache. + fs->inputs_fid_ = state->inputs(); + fs->outputs_fid_ = state->outputs(); + fs->extents_fid_ = state->extents(); + fs->map_value_to_fid_ = state->getValueMap(); } // Table TrieNode => Field: children: [ulong] diff --git a/csrc/python_frontend/fusion_cache.h b/csrc/python_frontend/fusion_cache.h index 190671b2b82..2d4f2533ba5 100644 --- a/csrc/python_frontend/fusion_cache.h +++ b/csrc/python_frontend/fusion_cache.h @@ -107,6 +107,14 @@ struct FusionSchedules { std::mutex scheds_lock; //! ID of fusion in python frontend fusion cache int64_t fusion_id_ = -1; + //! Fusion IDs of input arguments for FusionState + std::vector inputs_fid_; + //! IDs for Extents for TensorView input arguments for FusionState + std::vector extents_fid_; + //! Fusion IDs of output arguments for FusionState + std::vector outputs_fid_; + //! Map Fusion Val to its corresponding FusionDefinition index + std::unordered_map map_value_to_fid_; }; //! \struct TrieNode diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 09648a0bf36..05f12a7c2af 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -108,6 +108,17 @@ void FusionDefinition::finalizeDefinition() { throw; } + // The FusionState creates a mapping from CPP Fusion to its State objects. + // Since the CPP Fusion is cached in FusionCache and the FusionState is + // temporary, the information linking CPP Fusion and Python + // FusionDefinition is stored in FusionCache. + FusionSchedules* fs = + fusionCache()->queryFusionSchedules(fusion_id_.value()); + fs->inputs_fid_ = inputs(); + fs->outputs_fid_ = outputs(); + fs->extents_fid_ = extents(); + fs->map_value_to_fid_ = getValueMap(); + if (isDebugDumpEnabled(DebugDumpOption::FusionIrOriginal)) { printIr(); } @@ -121,6 +132,17 @@ void FusionDefinition::finalizeDefinition() { // build a proper fusion earlier. NVF_CHECK(!opt_e.has_value(), opt_e.value()); fusion_id_ = std::optional(trie_node_->fusion_id); + + // A CPP fusion already exists in the FusionCache for this FusionDefinition. + // In this case, a new CPP Fusion is not created, so the mapping from CPP + // fusion to Python FusionDefinition is not initialized. This state is + // stored within FusionSchedules and is retrieved for this FusionDefinition. + FusionSchedules* fs = + fusionCache()->queryFusionSchedules(fusion_id_.value()); + inputs_fid_ = fs->inputs_fid_; + outputs_fid_ = fs->outputs_fid_; + extents_fid_ = fs->extents_fid_; + map_value_to_fid_ = fs->map_value_to_fid_; } NVF_ERROR( diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 82879912509..154f8d28805 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -1368,7 +1368,7 @@ struct TensorRecord : RecordFunctor { } fd.setFusionState(outputs_.at(0).index, tv); - fd.addInput(tv); + fd.addInput(tv, outputs_.at(0).index); } void print(std::ostream& os, bool close_function = true) const final { @@ -1545,12 +1545,12 @@ struct OutputRecord : RecordFunctor { } tv_output->setAllocationDomain(allocation_domain, true); } - fd.addOutput(tv_output); + fd.addOutput(tv_output, args_.at(0).index); } else { NVF_CHECK( stride_order_.empty(), "stride_order can't be dictated for scalar outputs."); - fd.addOutput(output); + fd.addOutput(output, args_.at(0).index); } } } @@ -2015,7 +2015,7 @@ struct ScalarRecord : RecordFunctor { void operator()(FusionState& fd) final { Val* output = IrBuilder::create(value_, dtype_); if (!value_.hasValue()) { - fd.addInput(output); + fd.addInput(output, outputs_.at(0).index); } fd.setFusionState(outputs_.at(0).index, output); } diff --git a/csrc/python_frontend/fusion_state.cpp b/csrc/python_frontend/fusion_state.cpp index 99868f14b21..be8d8d0c514 100644 --- a/csrc/python_frontend/fusion_state.cpp +++ b/csrc/python_frontend/fusion_state.cpp @@ -85,6 +85,22 @@ std::unique_ptr FusionState::clone() { state->fusion_state_.insert( state->fusion_state_.end(), fusion_state_.begin(), fusion_state_.end()); state->num_recording_states_ = num_recording_states_; + std::copy( + inputs_fid_.begin(), + inputs_fid_.end(), + std::back_inserter(state->inputs_fid_)); + std::copy( + outputs_fid_.begin(), + outputs_fid_.end(), + std::back_inserter(state->outputs_fid_)); + std::copy( + extents_fid_.begin(), + extents_fid_.end(), + std::back_inserter(state->extents_fid_)); + std::copy( + map_value_to_fid_.begin(), + map_value_to_fid_.end(), + std::inserter(state->map_value_to_fid_, state->map_value_to_fid_.end())); return state; } @@ -108,6 +124,7 @@ void FusionState::buildFusionIr(Fusion* fusion) { e.what()); } } + addExtents(); } void FusionState::addRecord(RecordFunctor* record) { @@ -147,6 +164,10 @@ void FusionState::resetFusionState(Fusion* fusion, size_t size) { fusion_ = fusion; fusion_state_.clear(); fusion_state_.resize(size, {}); + inputs_fid_.clear(); + outputs_fid_.clear(); + extents_fid_.clear(); + map_value_to_fid_.clear(); } void FusionState::addFusionState(Val* val) { @@ -178,6 +199,7 @@ size_t FusionState::numFusionStates() const { void FusionState::setFusionState(size_t index, Val* val) { fusion_state_.at(index) = {val}; + map_value_to_fid_.emplace(val, (int64_t)index); } void FusionState::setFusionStateVector(size_t index, std::vector val) { @@ -189,14 +211,18 @@ void FusionState::setFusionStateVector(size_t index, std::vector val) { fusion_state_.at(index) = {val}; } -void FusionState::addInput(Val* input) { +void FusionState::addInput(Val* input, size_t index) { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); fusion_->addInput(input); + map_value_to_fid_.emplace(input, (int64_t)index); + inputs_fid_.push_back((int64_t)index); } -void FusionState::addOutput(Val* output) { +void FusionState::addOutput(Val* output, size_t index) { NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); fusion_->addOutput(output); + map_value_to_fid_.emplace(output, (int64_t)index); + outputs_fid_.push_back((int64_t)index); } void FusionState::aliasOutputToInput(Val* output, Val* input) { @@ -206,4 +232,63 @@ void FusionState::aliasOutputToInput(Val* output, Val* input) { fusion_->aliasOutputToInput(output, input, AllocationType::ReuseBuffer); } +const std::unordered_map& FusionState::getValueMap() + const { + return map_value_to_fid_; +} + +const std::vector& FusionState::inputs() const { + return inputs_fid_; +} + +const std::vector& FusionState::outputs() const { + return outputs_fid_; +} + +const std::vector& FusionState::extents() const { + return extents_fid_; +} + +std::vector FusionState::getExtents(Fusion* fusion) { + NVF_CHECK(fusion != nullptr, "Fusion is undefined."); + + std::vector extents; + for (Val* v : fusion->inputs()) { + // short-circuit: skip if not TensorView + if (!v->isA()) { + continue; + } + TensorView* tv = v->as(); + std::vector logical_dom = + TensorDomain::noReductions(tv->getLogicalDomain()); + std::transform( + logical_dom.begin(), + logical_dom.end(), + std::back_inserter(extents), + [](IterDomain* id) { return id->getMaybeExpandedExtent(); }); + } + return extents; +} + +void FusionState::addExtents() { + NVF_CHECK(fusion_ != nullptr, "Fusion is undefined."); + + // The size of the tensor dimensions can be used as an input of the + // segments. NvFuser does not support returning scalar values. Segmentation + // must pass those sizes as segment arguments manually. + std::vector extents = getExtents(fusion_); + for (Val* extent : extents) { + int64_t num_extents = (int64_t)extents_fid_.size(); + // Use negative numbers to represent extent of iterDomains to avoid conflict + // with non-negative numbers used for scalars, vectors, and tensors. + // The extents are ordered based on the order of the fusion's inputs. + int64_t extent_fid = -num_extents - 1; + extents_fid_.push_back(extent_fid); + // The extent can already exist in the fusion. However, since scalars cannot + // be passed between segments, always overwrited existing fids. The original + // fusion definition will provide scalar extents. + map_value_to_fid_[extent] = extent_fid; + } +} + } // namespace nvfuser::python_frontend diff --git a/csrc/python_frontend/fusion_state.h b/csrc/python_frontend/fusion_state.h index bd75f7af5d6..7a83886514a 100644 --- a/csrc/python_frontend/fusion_state.h +++ b/csrc/python_frontend/fusion_state.h @@ -79,12 +79,21 @@ class FusionState { NVF_API void setFusionStateVector(size_t index, std::vector val); //! Adds a Tensor/Scalar input to the Fusion object - NVF_API void addInput(Val* input); + NVF_API void addInput(Val* input, size_t index); //! Adds a Tensor/Scalar output to the Fusion object - NVF_API void addOutput(Val* output); + NVF_API void addOutput(Val* output, size_t index); //! Alias an Output to Input in the Fusion object NVF_API void aliasOutputToInput(Val* output, Val* input); + //! Get map between CPP Fusion and Python FusionDefinition + NVF_API const std::unordered_map& getValueMap() const; + //! Get indicies for the inputs of FusionState + NVF_API const std::vector& inputs() const; + //! Get indicies for the outputs of FusionState + NVF_API const std::vector& outputs() const; + //! Get indicies for the extents of TensorView inputs of FusionState + NVF_API const std::vector& extents() const; + //! Add a Record void addRecord(RecordFunctor* record); //! Builds an nvFuser Fusion IR object @@ -94,6 +103,10 @@ class FusionState { std::unique_ptr clone(); private: + //! Get extents for TensorView inputs in Fusion + std::vector getExtents(Fusion* fusion); + //! Add extents of TensorView inputs to FusionState + void addExtents(); //! Change the fusion ptr and reset its state void resetFusionState(Fusion* fusion, size_t size); @@ -104,10 +117,18 @@ class FusionState { std::vector> recording_; //! A vector of state that represents Tensors/Vectors/Scalars std::vector recording_state_; + //! Input arguments for FusionState + std::vector inputs_fid_; + //! Output arguments for FusionState + std::vector outputs_fid_; + //! Extents for TensorView input arguments for FusionState + std::vector extents_fid_; + //! Map Fusion Val to its corresponding FusionDefinition index + std::unordered_map map_value_to_fid_; private: //! A ptr to the container used when building the Fusion IR from a definition - Fusion* fusion_; + Fusion* fusion_ = nullptr; //! A vector of nvFuser Fusion IR TensorViews/Vectors/Scalars for building the //! Fusion IR graph. //! NOTE: Vectors are represented by a vector. This could diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index b229107c45b..79f460fa232 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -999,6 +999,9 @@ void initNvFuserPythonBindings(PyObject* module) { // Mark the end of a schedule inst::Trace::instance()->endEvent(nullptr); }) + .def("inputs", [](FusionDefinition& self) { return self.inputs(); }) + .def("outputs", [](FusionDefinition& self) { return self.outputs(); }) + .def("extents", [](FusionDefinition& self) { return self.extents(); }) .def( "__repr__", [](FusionDefinition& self) { diff --git a/nvfuser/__init__.py b/nvfuser/__init__.py index 4b4f25b9d66..7d9048e7bf6 100644 --- a/nvfuser/__init__.py +++ b/nvfuser/__init__.py @@ -53,7 +53,6 @@ class FusionDefinition(_C._FusionDefinition): def __init__(self, id=None, max_length=1024): super(FusionDefinition, self).__init__(id, max_length) self.profiled = False - self.inputs = None def __enter__(self): return self._setup_definition() diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index 874223471eb..e0597757a9c 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4601,6 +4601,42 @@ def fusion_func(fd: FusionDefinition) -> None: for out in nvf_out: self.assertTrue(out.allclose(x[:, 1:, 2:])) + def test_fusion_information(self): + inputs = [ + torch.ones(2, 4, 8, device="cuda"), + torch.ones(2, 4, 8, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition) -> None: + t0 = fd.from_pytorch(inputs[0]) + t1 = fd.from_pytorch(inputs[1]) + c2 = fd.define_scalar(3.0) + + t3 = fd.ops.add(t0, t1) + t4 = fd.ops.mul(t3, c2) + t5 = fd.ops.sum(t4, [-1], False, DataType.Float) + + fd.add_output(t5) + + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) + eager_out = torch.sum((inputs[0] + inputs[1]) * 3.0, dim=-1) + self.assertEqual(eager_out, nvf_out[0]) + + with FusionDefinition() as fd: + fusion_func(fd) + + nvf_out1 = fd.execute(inputs) + self.assertEqual(eager_out, nvf_out1[0]) + + # The input tensors are t0 and t1. + self.assertEqual(fd.inputs(), [0, 1]) + # The output tensors is t5. + self.assertEqual(fd.outputs(), [5]) + # The extents correspond with the dimensions for each input tensor. + # There are two input tensors with three dimensions each, so the + # extents range from [-1, -6]. + self.assertEqual(fd.extents(), [idx for idx in range(-1, -7, -1)]) + def test_issue_3292(self): inputs = [ torch.testing.make_tensor( From a59fd730c0a689521dd1bb59652233b000f97894 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 31 Oct 2024 09:34:21 -0700 Subject: [PATCH 12/15] Fix IterDomain::merge with expanded inner input (#3316) I believe this is just a trivial bug, and, luckily, I don't think it would actually affect anything. This could matter if an expanded iter domain got merged with a non-broadcast iter domain as part of a reshape op, but reshape converts expanded iter domains to non-broadcast iter domains, so this bug won't matter. In the case of normal scheduling, whether the output of a merge with an expanded broadcast is still a broadcast or not shouldn't matter, I believe. --- csrc/ir/nodes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 0091966e665..3c7be6b6262 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -2593,7 +2593,7 @@ IterDomain* IterDomain::merge( } else { expanded_extent = mul(outer->expandedExtent(), inner->extent()); } - } else if (outer->hasExpandedExtent() && inner->hasExpandedExtent()) { + } else if (!outer->hasExpandedExtent() && inner->hasExpandedExtent()) { if (outer->isBroadcast()) { expanded_extent = inner->expandedExtent(); } else { From a4b549a385a47523233bd48cdb99acbf877be49d Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 31 Oct 2024 09:46:07 -0700 Subject: [PATCH 13/15] Fix `is_clonable` in `test_issue_3292` (#3319) This PR adds `ceilDiv` to serde, which enables `is_clonable` in `test_issue_3292`. --- csrc/serde/fusion_record.cpp | 1 + tests/python/test_python_frontend.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index 5de2cda9873..7e23adf2b69 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -840,6 +840,7 @@ void RecordFunctorFactory::setupFunctionMaps() { NVFUSER_BINARY_TV_OP("bitwise_right_shift", bitwise_right_shift) NVFUSER_BINARY_TV_OP("logical_right_shift", logical_right_shift) NVFUSER_BINARY_TV_OP("gcd", gcd) + NVFUSER_BINARY_TV_OP("ceilDiv", ceilDiv) NVFUSER_BINARY_TV_ALPHA_OP("add_alpha", add_alpha) NVFUSER_BINARY_TV_ALPHA_OP("sub_alpha", sub_alpha) diff --git a/tests/python/test_python_frontend.py b/tests/python/test_python_frontend.py index e0597757a9c..0f2a9f9314d 100644 --- a/tests/python/test_python_frontend.py +++ b/tests/python/test_python_frontend.py @@ -4708,5 +4708,4 @@ def fusion_func(fd: FusionDefinition) -> None: T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0) fd.add_output(T223) - # is_clonable=False is because translation fails with missing ceilDiv - nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False) + nvf_out, _ = self.exec_nvfuser(fusion_func, inputs) From 4cf9533d5f6c927148ccd11bf0f48cb23c937e8c Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Thu, 31 Oct 2024 11:06:12 -0700 Subject: [PATCH 14/15] Use a single elect sync ite for all trasactions (#3314) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Before: ```C++ if (elect-sync) { arrive-expect-tx1; tma1; } if (elect-sync) { arrive-expect-tx2; tma2; } ``` After: ``` if (elect-sync) { arrive-expect-tx1; tma1; arrive-expect-tx2; tma2; } ``` Perf: ```markdown Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name -------- --------------- --------- -------- -------- -------- -------- ----------- ---------------------------------------------------------------------------------------------------- 36.0 151775 1 151775.0 151775.0 151775 151775 0.0 ::nvfuser_none_f0_c0_r0_g0(::Tensor<::__half, (int)3, (int)3>, … 20.7 87135 1 87135.0 87135.0 87135 87135 0.0 nvjet_hsh_256x128_64x4_1x2_h_bz_coopA_NTT ``` nvFuser/cuBLAS = `57.4%`. --- csrc/device_lower/pass/circular_buffer.cpp | 23 +++++++++++++++++----- csrc/index_compute.cpp | 1 + 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 40bb8fc3294..f2eb30297a5 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -700,6 +700,17 @@ class CloneTmaCircularBufferLoopAndInsertSync return wait_exprs; } + // If there is already an if-then-else with electSync() predicate, use it. + // Otherwise, create a new one. + kir::IfThenElse* getElectSyncIfThenElse() { + if (elect_sync_if_then_else_ == nullptr) { + elect_sync_if_then_else_ = IrBuilder::create( + IrBuilder::create(PredicateType::ElectSync)); + for_loop_stack_.back()->body().push_back(elect_sync_if_then_else_); + } + return elect_sync_if_then_else_; + } + // This function selects a single thread to launch tma load and mbarrier // arrive_expected_tx operations. The remaining threads will simply arrive // at the mbarrier. @@ -719,16 +730,14 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(mbarrier_arrive_tx_ != nullptr); NVF_ERROR(expr != nullptr); - // Create the if-then-else with electSync() predicate for the arrive expect - // transaction. - kir::IfThenElse* if_expr = IrBuilder::create( - IrBuilder::create(PredicateType::ElectSync)); + // Use the if-then-else with electSync() predicate for the arrive expect + // and cpAsyncBulk operations. + kir::IfThenElse* if_expr = getElectSyncIfThenElse(); // A single thread issues arriveExpectTx with expected transactions and // launches the TMA load. if_expr->thenBody().push_back(mbarrier_arrive_tx_); if_expr->thenBody().push_back(expr); - for_loop_stack_.back()->body().push_back(if_expr); mbarrier_arrive_tx_ = nullptr; } @@ -841,6 +850,10 @@ class CloneTmaCircularBufferLoopAndInsertSync // Mbarrier_ArriveExpectTx to add to cloned_top_level_loop kir::MBarrierArriveExpectTx* mbarrier_arrive_tx_ = nullptr; + // ElectSync if-then-else for the cloned loop. We put all the circular buffer + // load TMA operations under this if-then-else. + kir::IfThenElse* elect_sync_if_then_else_ = nullptr; + // The circular buffered TVs for the loop being cloned std::unordered_set circular_buffer_load_tvs_; }; diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 81c27dddfc8..553787c66bf 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2238,6 +2238,7 @@ kir::TensorIndex* Index::getConsumerIndex( DataType as_type) { Val* index = nullptr; if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || + ir_utils::isCpAsyncBulkLoad(consumer->definition()) || (isIdModelOptionEnabled(IdModelEnableOption::ConsumerIndex) && GpuLower::current()->isTensorIndexerEnabled())) { index = GpuLower::current()->tensorIndexer().getLinearIndex( From abdc3e105267c5a0cb9b83f849d7e49ee0fe6cd6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 31 Oct 2024 14:16:22 -0400 Subject: [PATCH 15/15] Remove MmaOpDetails::input_layout and getInputLayout (#3322) There is no reason for us to check the Mma layout anymore when defining an MmaOp, since that is all handled in the scheduler now. I also added a test where a new batch dimension is broadcasted before defining the MmaOp. Fixes #2273. --- csrc/ir/utils.cpp | 58 ----------------------------------- csrc/ir/utils.h | 2 -- tests/cpp/test_matmul.cpp | 63 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 60 deletions(-) diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index d52c7858923..91a1170bf38 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -1224,55 +1224,6 @@ TensorViewDetails getDetailsFor(const std::vector& dims) { return details; } -MmaLayout getInputLayout( - const TensorViewDetails& in_a, - const TensorViewDetails& in_b, - const MmaOp::AxesData& m_axes, - const MmaOp::AxesData& n_axes, - const MmaOp::AxesData& k_axes) { - // TT layout (b - broadcast, r - reduction): - // A = [M, K, b] - // B = [b, K, N] - // C = [M, r, N] (root domain) - if ((m_axes.front() < in_a.bcasts.front()) && - (k_axes.front() < in_a.bcasts.front()) && - (in_b.bcasts.front() < k_axes.front()) && - (in_b.bcasts.front() < n_axes.front())) { - return MmaLayout::TT; - } - // TN layout (b - broadcast, r - reduction): - // A = [M, b, K] - // B = [b, N, K] - // C = [M, N, r] (root domain) - if ((m_axes.front() < in_a.bcasts.front()) && - (in_a.bcasts.front() < k_axes.front()) && - (in_b.bcasts.front() < n_axes.front()) && - (in_b.bcasts.front() < k_axes.front())) { - return MmaLayout::TN; - } - // NT layout (b - broadcast, r - reduction): - // A = [K, M, b] - // B = [K, b, N] - // C = [r, M, N] (root domain) - if ((k_axes.front() < in_a.bcasts.front()) && - (m_axes.front() < in_a.bcasts.front()) && - (k_axes.front() < in_b.bcasts.front()) && - (in_b.bcasts.front() < n_axes.front())) { - return MmaLayout::NT; - } - // NN layout (b - broadcast, r - reduction): - // A = [b, K, M] - // B = [N, K, b] - // C = [N, r, M] (root domain) - if ((in_a.bcasts.front() < k_axes.front()) && - (k_axes.front() < m_axes.front()) && (n_axes.front() < k_axes.front()) && - (k_axes.front() < in_b.bcasts.front())) { - return MmaLayout::NN; - } - - NVF_THROW("Unsupported input layout"); -} - MmaOpDetails getMmaOpDetails( TensorView* out, TensorView* in_a, @@ -1405,15 +1356,6 @@ MmaOpDetails getMmaOpDetails( !details.k_axes.empty(), "MmaOp inputs must define at least a single K dimension"); - // TODO: for tensor contraction / split-k uses of MmaOp different input layout - // rules may be needed - details.input_layout = getInputLayout( - in_a_details, - in_b_details, - details.m_axes, - details.n_axes, - details.k_axes); - return details; } diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index 74dcf5abb9d..b02fb2fbe3e 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -38,8 +38,6 @@ struct MmaOpDetails { // Concrete or broadcast axes that are present in all inputs // and output AxesData batch_axes; - // A placeholder for mma input layout - std::optional input_layout = std::nullopt; }; // A helper structure with pieces of information about TensorView diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 381fd095623..dcde07275d7 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -140,6 +140,69 @@ TEST_P(MatmulTestWithLayout, AmpereMatmul) { NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); } +// Single batch dimension which is broadcast +TEST_P(MatmulTestWithLayout, AmpereMatmulBroadcastBatch) { + // Keep multiples of 8 to keep vectorizable. + int M = 504, N = 136, K = 248; + + Fusion fusion; + FusionGuard fg(&fusion); + + auto shapes = matmulAtInputShape3DTuring(-1, -1, -1, layout); + + auto tv0 = makeContigConcreteTensor(shapes.first, DataType::Half); + auto tv1 = makeContigConcreteTensor(shapes.second, DataType::Half); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + tv0 = canonicalizeInputToBMNK(tv0, layout, MmaOperand::A); + tv1 = canonicalizeInputToBMNK(tv1, layout, MmaOperand::B); + // Broadcast inputs to 1, M, 1, K and 1, 1, N, K + tv0 = broadcast(tv0, {true, false, false, false}); + tv1 = broadcast(tv1, {true, false, false, false}); + auto tv2 = fusedMultiplySum(tv0, tv1, {-1}); + + fusion.addOutput(tv2); + + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 128, 32); + gemm_tile.warp_tile = GemmTile(64, 64, 32); + gemm_tile.instruction_tile = GemmTile(16, 8, 16); + + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 4}; + mparams.mma_macro = MmaMacro::Ampere_16_8_16; + mparams.tile_sizes = gemm_tile; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = true; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + auto inputs = matmulAtInput3DTuring(M, N, K, layout); + + FusionExecutor fe; + NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK( + 8, + 0, + fe.compileFusion( + &fusion, + {inputs.first, inputs.second}, + LaunchParams(), + matmul_cparams)); + ASSERT_TRUE(getBankConflictInfo(fe.kernel()).empty()); + ASSERT_FALSE( + PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(fe.kernel())); + auto cg_outputs = fe.runFusion({inputs.first, inputs.second}); + auto tref = + atMatmul( + inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout) + .unsqueeze(0); + NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001)); +} + TEST_P(MatmulTestWithLayout, AmperePrologueFusionBroadcast) { // Keep multiples of 8 to keep vectorizable. int M = 504, N = 136, K = 248;