diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp
index 625bd95985e..6dd34f4cab9 100644
--- a/csrc/scheduler/normalization_inner_outer.cpp
+++ b/csrc/scheduler/normalization_inner_outer.cpp
@@ -967,11 +967,7 @@ void scheduleInnerOuterPersistentKernel(
       scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet);
   const auto& unroll_vectorizable_cached_tvs =
       reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
-          inner_reference_tv,
-          is_vectorize,
-          cached_inputs,
-          cached_outputs,
-          smem_consumers);
+          inner_reference_tv, is_vectorize, cached_inputs, cached_outputs);
   reduction_scheduler_utils::propagateParallelization(
       inner_reduction_tvs[0],
       inner_reference_tv,
@@ -998,8 +994,7 @@ void scheduleInnerOuterPersistentKernel(
             outer_reference_tvs[i],
             is_vectorize,
             cached_inputs,
-            cached_outputs,
-            smem_consumers);
+            cached_outputs);
     reduction_scheduler_utils::propagateParallelization(
         outer_reduction_tvs[i],
         outer_reference_tvs[i],
@@ -1044,6 +1039,13 @@ void scheduleInnerOuterPersistentKernel(
     }
   }
 
+  // Vectorized loading from shared memory needs special handling since the
+  // inputs and the shared memory tensor may have different data types.
+  if (is_vectorize) {
+    reduction_scheduler_utils::sharedMemoryConsumerVectorization(
+        smem_consumers, rparams->unroll_factor_inner_reduction);
+  }
+
   // Remove dummy outputs as they can inadvertently affect CA positions
   for (auto output : dummy_outputs) {
     fusion->removeOutput(output);
diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp
index 7bf100adca3..2601fcc469c 100644
--- a/csrc/scheduler/normalization_utils.cpp
+++ b/csrc/scheduler/normalization_utils.cpp
@@ -1420,6 +1420,7 @@ void schedulePersistentKernel(
       unroll,
       vectorize,
       is_outer_grid_persistence,
+      rparams->unroll_factor_inner_reduction,
       reduction_tvs,
       cached_inputs,
       cached_outputs,
diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp
index 8cca314f9ae..87f9d2bffad 100644
--- a/csrc/scheduler/reduction.cpp
+++ b/csrc/scheduler/reduction.cpp
@@ -1259,6 +1259,7 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
       unroll,
       vectorize,
       use_iter_grouped_reduction,
+      rparams->unroll_factor_inner_reduction,
       reduction_tvs,
       cached_inputs,
       cached_outputs);
diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp
index f0d03e02ed8..34db3133da7 100644
--- a/csrc/scheduler/reduction_utils.cpp
+++ b/csrc/scheduler/reduction_utils.cpp
@@ -5,14 +5,14 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
-#include
-
 #include
 #include
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -348,6 +348,7 @@ void multiReductionInliner(
     const bool is_unroll_or_vectorization,
     const bool vectorize,
     const bool use_grouped_reduction,
+    const int64_t vectorization_factor,
     std::vector<TensorView*> reduction_tvs,
     std::vector<TensorView*> cached_inputs,
     std::vector<std::pair<TensorView*, TensorView*>> cached_outputs,
@@ -361,7 +362,7 @@ void multiReductionInliner(
   }
 
   const auto& unroll_vectorizable_cached_tvs = getCachedTvsToUnrollOrVectorize(
-      reference_tv, vectorize, cached_inputs, cached_outputs, smem_consumers);
+      reference_tv, vectorize, cached_inputs, cached_outputs);
   reduction_scheduler_utils::propagateParallelization(
       reduction_tv,
       reference_tv,
@@ -370,6 +371,13 @@ void multiReductionInliner(
       reduction_tvs,
       unroll_vectorizable_cached_tvs);
 
+  // Vectorized loading from shared memory needs special handling since the
+  // inputs and the shared memory tensor may have different data types.
+  if (vectorize) {
+    reduction_scheduler_utils::sharedMemoryConsumerVectorization(
+        smem_consumers, vectorization_factor);
+  }
+
   // Remove dummy outputs as they can inadvertently affect CA positions
   for (auto output : dummy_outputs) {
     fusion->removeOutput(output);
@@ -428,8 +436,7 @@ std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
     TensorView* reference_tv,
     bool vectorize,
     const std::vector<TensorView*>& cached_inputs,
-    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs,
-    const std::vector<TensorView*>& smem_consumers) {
+    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs) {
   auto reduced_tv = ir_utils::getSoleProducerTv(reference_tv);
   // Grab all tensor views that should be vectorized
   auto vectorizable_inputs_outputs =
@@ -469,18 +476,6 @@ std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
     }
   }
 
-  if (vectorize) {
-    for (auto tv : smem_consumers) {
-      // smem_consumers were added in schedule process
-      // movePersistentBufferToSmem() using cacheAfter()
-      NVF_ERROR(
-          vectorizable_expr(tv->definition()),
-          "Expected a vectorizable expression, but got: ",
-          tv->definition()->toString());
-      unroll_vectorizable_tvs.emplace(tv);
-    }
-  }
-
   return unroll_vectorizable_tvs;
 }
 
@@ -1009,5 +1004,53 @@ std::ostream& operator<<(std::ostream& os, ReductionType reduction_type) {
   return os;
 }
 
+void sharedMemoryConsumerVectorization(
+    std::vector<TensorView*>& smem_consumers,
+    int64_t io_vectorization_factor) {
+  for (auto tv : smem_consumers) {
+    // They were created with cacheAfter during scheduling.
+    NVF_ERROR(
+        tv->definition()->isA<LoadStoreOp>(),
+        "smem consumers should be LoadStoreOp. Got: ",
+        tv->definition()->toString());
+
+    // Non-concretized broadcast domains are moved to the innermost positions
+    // before transform propagation; skip these axes.
+    int64_t vect_axis_pos = -1;
+    while (tv->axis(vect_axis_pos)->isBroadcast()) {
+      vect_axis_pos--;
+      NVF_ERROR(
+          vect_axis_pos + tv->nDims() >= 0,
+          "Out of bound access when visiting dim ",
+          vect_axis_pos,
+          " in Tv: ",
+          tv->toString());
+    }
+    // They were transformed so that the innermost axis has an extent equal
+    // to the vectorization factor set for the I/O tvs.
+    NVF_ERROR(
+        tv->axis(vect_axis_pos)->extent()->isConst(),
+        "Extent of the innermost axis of smem consumers should be constant. Got: ",
+        tv->toString());
+    auto innermost_extent =
+        tv->axis(vect_axis_pos)->extent()->evaluate().as<int64_t>();
+    NVF_ERROR(
+        innermost_extent == io_vectorization_factor,
+        "Extent of the innermost axis of smem consumers should be equal to the vectorization factor of fusion inputs and outputs. Got: ",
+        innermost_extent,
+        ", expected: ",
+        io_vectorization_factor);
+    auto dtype_bytes = dataTypeSize(tv->getDataType().value());
+    auto max_vect_factor =
+        SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_bytes;
+    // An additional split is added if the innermost extent is greater than
+    // the max vectorization factor.
+    if (innermost_extent > max_vect_factor) {
+      tv->split(vect_axis_pos, max_vect_factor);
+    }
+    tv->axis(vect_axis_pos)->parallelize(ParallelType::Vectorize);
+  }
+}
+
 } // namespace reduction_scheduler_utils
 } // namespace nvfuser
diff --git a/csrc/scheduler/reduction_utils.h b/csrc/scheduler/reduction_utils.h
index 78096210afb..713c399c03b 100644
--- a/csrc/scheduler/reduction_utils.h
+++ b/csrc/scheduler/reduction_utils.h
@@ -35,6 +35,7 @@ void multiReductionInliner(
     const bool unroll,
     const bool vectorize,
     const bool use_grouped_reduction,
+    const int64_t vectorization_factor,
     std::vector<TensorView*> reduction_tvs,
     std::vector<TensorView*> cached_inputs,
     std::vector<std::pair<TensorView*, TensorView*>> cached_outputs,
@@ -65,14 +66,11 @@ void propagateRFactor(
 // is_vectorize: Indicates if vectorization is applied in the scheduler.
 // cached_inputs: Inputs cached in registers or shared memory.
 // cached_outputs: Outputs cached in registers.
-// smem_consumers: Consumers of shared memory persistent buffers, they are
-// register cached Tvs after the shared memory tv.
 NVF_API std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
     TensorView* reference_tv,
     bool is_vectorize,
     const std::vector<TensorView*>& cached_inputs,
-    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs,
-    const std::vector<TensorView*>& smem_consumers);
+    const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs);
 
 // Propagate parallelization from the reference TensorView to other TensorViews.
 // Unroll, Vectorize, and MisalignedVectorize types are explicitly handled for
@@ -139,5 +137,23 @@ std::string toString(ReductionType reduction_type);
 ReductionType getReductionType(Fusion* fusion);
 ReductionType getReductionType(const std::vector<TensorView*>& reduction_tvs);
 
+/**
+ * @brief Vectorize shared memory consumers
+ *
+ * Applies vectorization to shared memory consumers.
+ * If the extent of the last dim times the data type size exceeds the hardware
+ * vectorization limit, an additional split is added.
+ *
+ * @param smem_consumers Vector of TensorView pointers representing shared
+ * memory consumers
+ * @param io_vectorization_factor Vectorization factor set for fusion inputs and
+ * outputs
+ * @note TODO: Optimize writing to shared memory and address bank conflicts for
+ * float32 with innermost extent of 8
+ */
+void sharedMemoryConsumerVectorization(
+    std::vector<TensorView*>& smem_consumers,
+    const int64_t io_vectorization_factor);
+
 } // namespace reduction_scheduler_utils
 } // namespace nvfuser
diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp
index c3ff5928742..2071aeb0e86 100644
--- a/tests/cpp/test_combined_inner_outer_reduction.cpp
+++ b/tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -610,7 +610,7 @@ TEST_F(CombinedSchedulerTest, CombinedReduction) {
       false,
       inner_reduction_tvs,
       reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
-          reference_tv_inner, true, cached_inputs, cached_outputs, {}));
+          reference_tv_inner, true, cached_inputs, cached_outputs));
   reduction_scheduler_utils::propagateParallelization(
       outer_reduction_tv,
       reference_tv_outer,
@@ -618,7 +618,7 @@ TEST_F(CombinedSchedulerTest, CombinedReduction) {
       false,
       outer_reduction_tvs,
       reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
-          reference_tv_outer, true, cached_inputs, cached_outputs, {}));
+          reference_tv_outer, true, cached_inputs, cached_outputs));
 
   inlineMost();
   LaunchParams launch_constraints;
@@ -773,7 +773,7 @@ TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) {
       false,
       inner_reduction_tvs,
       reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
-          reference_tv_inner, true, cached_inputs, cached_outputs, {}),
+          reference_tv_inner, true, cached_inputs, cached_outputs),
       {selected_tvs_inner.begin(), selected_tvs_inner.end()});
 
   const auto& selected_tvs_outer =
@@ -787,7 +787,7 @@ TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) {
       false,
       outer_reduction_tvs,
       reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
-          reference_tv_outer, true, cached_inputs, cached_outputs, {}),
+          reference_tv_outer, true, cached_inputs, cached_outputs),
       {selected_tvs_outer.begin(), selected_tvs_outer.end()});
 
   std::vector<TensorView*> cached_gmem_temp{partialResult};
@@ -926,4 +926,72 @@ TEST_F(CombinedSchedulerTest, InnerOuterNoOuterBroadcastTv) {
       "",
       persistent_params->lparams);
 }
+
+// Reproduce error found in:
+// thunder/tests/test_torch_compile_executor.py::test_torch_compile_cat_nvfuser_phi2_tanh
+// Only happens when shared memory persistent is used.
+TEST_F(CombinedSchedulerTest, SharedMemoryPersistentVectFactor) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+  // When the input is float16, the vectorization factor is set to 8.
+  // If the persistent buffer tv1 is stored in shared memory and is not
+  // projected to inputs, the scheduler adds a cacheAfter to load tv1 from
+  // shared memory to registers in a vectorized manner, avoiding bank conflicts.
+  // However, since tv1 is float32, we can't directly use the vectorization
+  // factor set for float16 inputs because the maximum allowed vectorization
+  // width is 16 bytes.
+  const int dim0 = 1024;
+  const int dim1 = 4096;
+  auto dtype = DataType::Half;
+  auto tv0 = makeContigTensor(2, dtype);
+  fusion.addInput(tv0);
+  auto tv1 = castOp(DataType::Float, tv0);
+  auto tv2 = sum(tv1, {1});
+  auto tv3 = broadcast(tv2, {false, true});
+  auto tv4 = add(tv3, tv1);
+  auto tv5 = sum(tv1, {0});
+  auto tv6 = castOp(DataType::Half, tv4);
+  auto tv7 = castOp(DataType::Half, tv5);
+  fusion.addOutput(tv6);
+  fusion.addOutput(tv7);
+
+  Fusion fusion_copy = fusion;
+
+  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({dim0, dim1}, options);
+  std::vector<c10::IValue> aten_inputs = {t0};
+
+  SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs);
+  ASSERT_TRUE(Schedule::canSchedule(
+      SchedulerType::InnerOuterPersistent, &fusion, runtime_info));
+  auto scheduler = SchedulerEntry::makeSchedulerInstance(
+      SchedulerType::InnerOuterPersistent);
+  auto heuristic_params = scheduler->computeHeuristics(&fusion, runtime_info);
+
+  // Disable projection to inputs, so the shared memory buffer uses float32.
+  heuristic_params->as<ReductionParams>()->project_persistent_buffers = false;
+  // Set the vectorization factor to 8, so the extent of the innermost
+  // dimension exceeds 16 bytes (8 x 4 = 32 bytes).
+  heuristic_params->as<ReductionParams>()->unroll_factor_inner_reduction = 8;
+  // When computing heuristics, the buffer is projected to inputs and the
+  // shared memory persistent buffer is the input, tv0. Since the heuristics
+  // were modified to disable projection to inputs, the buffer stored in
+  // shared memory needs to be updated to the original unprojected buffer, tv1.
+  heuristic_params->as<ReductionParams>()->smem_persistent_buffers =
+      std::vector<TensorView*>{tv1};
+  scheduler->schedule(&fusion, heuristic_params.get());
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, aten_inputs);
+
+  for (auto tv : fusion.allTvs()) {
+    if (tv->getMemoryType() == MemoryType::Shared) {
+      for (auto consumer : ir_utils::consumerTvsOf(tv)) {
+        EXPECT_TRUE(isVectorized(consumer));
+      }
+    }
+  }
+  auto cg_outputs = fe.runFusion(
+      aten_inputs, heuristic_params->as<ReductionParams>()->lparams);
+  testValidate(&fusion_copy, cg_outputs, aten_inputs, __LINE__, __FILE__);
+}
 } // namespace nvfuser