From 12568b60897cd97940557692937601b4605342fb Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 17 Oct 2024 16:02:30 -0700 Subject: [PATCH 01/17] unroll the outer dim --- csrc/scheduler/pointwise.cpp | 43 +++++++++++++++++++++--------------- tests/cpp/test_pointwise.cpp | 8 +++++-- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index be887e281f7..55216d09409 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -704,16 +704,19 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { if (pparams->vectorization_factor > 1) { reference_tv->split(1, pparams->vectorization_factor); } - // [outer | inner/vect, vect] + // [outer | i-remainder, Vect] reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); - // [outer | inner/vect/TIDx, TIDx, vect] - reference_tv->split(1, pparams->unroll_factor); - // [outer | inner/vect/TIDx/unroll, unroll, TIDx, vect] + // [outer | i-remainder, TIDx, Vect] + + reference_tv->split(0, pparams->unroll_factor); + // [o-remainder, Unroll| i-remainder, TIDx, Vect] reference_tv->split(0, 1); - // [outer, unswitch | inner/vect/TIDx/unroll, unroll, TIDx, vect] - reference_tv->reorder({{1, 2}}); - // [outer, i-remainder, unswitch, unroll, TIDx, vect] + // [o-remainder, Unswitch, Unroll | i-remainder, TIDx, Vect] + + reference_tv->reorder({{3, 1}}); + // [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] + reference_tv->axis(2)->parallelize(ParallelType::Unswitch); // Here we do not set axis(3)->parallelize(Unroll) because we do not want // it to be propagated. We manually unroll by splitting the inline @@ -724,21 +727,22 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { if (pparams->vectorization_factor > 1) { vectorize_id = reference_tv->axis(5); } - //[outer | i-remainder, Unswitch, Unroll, TIDx, vect] + // [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] } // Move out of the way to furthest left point reference_tv->reorder({{1, 0}}); - - //[i-remainder | outer | Unswitch, Unroll, TIDx, vect] + // [i-remainder, o-remainder, Unswitch, Unroll, TIDx, Vect] if (pparams->split_block) { reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy)); + // [i-remainder, o-remainder, TIDy, Unswitch, Unroll, TIDx, Vect] if (pparams->flip_grid_binding) { - // [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, vect] + // [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, Vect] reference_tv->axis(1)->parallelize(ParallelType::BIDx); reference_tv->axis(2)->parallelize(ParallelType::TIDy); if (pparams->split_grid_y_dim) { - // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx] + // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx, + // Vect] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); unswitch_pos = 5; @@ -747,11 +751,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { unswitch_pos = 4; } } else { - // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx] + // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx, Vect] reference_tv->axis(0)->parallelize(ParallelType::BIDx); reference_tv->axis(2)->parallelize(ParallelType::TIDy); if (pparams->split_grid_y_dim) { - // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx] + // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx, + // Vect] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); unswitch_pos = 5; @@ -761,9 +766,9 @@ 
void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
       }
     }
   } else {
-    // [BIDy | BIDx | Unswitch, Unroll, TIDx]
+    // [BIDy | BIDx | Unswitch, Unroll, TIDx, Vect]
     if (pparams->flip_grid_binding) {
-      // [BIDy | BIDx | Unswitch, Unroll, TIDx]
+      // [BIDy | BIDx | Unswitch, Unroll, TIDx, Vect]
       reference_tv->axis(1)->parallelize(ParallelType::BIDx);
       if (pparams->split_grid_y_dim) {
         // [i-remainder, BIDy{65535} | BIDx | Unswitch, Unroll, TIDx]
@@ -771,18 +776,20 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
         unswitch_pos = 4;
       } else {
+        // [BIDy | BIDx | Unswitch, Unroll, TIDx, Vect]
         reference_tv->axis(0)->parallelize(ParallelType::BIDy);
         unswitch_pos = 3;
       }
     } else {
-      // [BIDx | BIDy | Unswitch, Unroll, TIDx]
+      // [BIDx | BIDy | Unswitch, Unroll, TIDx, vect]
       reference_tv->axis(0)->parallelize(ParallelType::BIDx);
       if (pparams->split_grid_y_dim) {
-        // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx]
+        // [BIDx | i-remainder, BIDy{65535} | Unswitch, Unroll, TIDx, vect]
         reference_tv->split(1, 65535);
         reference_tv->axis(2)->parallelize(ParallelType::BIDy);
         unswitch_pos = 4;
       } else {
+        // [BIDx | BIDy | Unswitch, Unroll, TIDx, vect]
         reference_tv->axis(1)->parallelize(ParallelType::BIDy);
         unswitch_pos = 3;
       }
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index f021c29a8e1..283c7101750 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -678,9 +678,11 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
   auto tv3 = add(tv0, tv2);
   fusion->addOutput(tv3);
 
+  int dim0 = 1024;
+  int dim1 = 2048;
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  auto t0 = at::randn({1024, 2048}, options);
-  auto t1 = at::randn({2048}, options);
+  auto t0 = at::randn({dim0, dim1}, options);
+  auto t1 = at::randn({dim1}, options);
   std::vector<c10::IValue> runtime_inputs{t0, t1};
 
   // generate heuristics
@@ -700,6 +702,8 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
   FusionExecutor fe;
   fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams);
   auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams);
+  const auto& lparams = fe.lastLaunchParams();
+  ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor);
   testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__);
 }
 } // namespace nvfuser

From f5b349fc1a33631f1a6c8af034721d547f4bd774 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Fri, 18 Oct 2024 17:42:58 -0700
Subject: [PATCH 03/17] comment

---
 csrc/scheduler/pointwise_heuristic.h | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/csrc/scheduler/pointwise_heuristic.h b/csrc/scheduler/pointwise_heuristic.h
index 9ca89919ad8..95a1112dd92 100644
--- a/csrc/scheduler/pointwise_heuristic.h
+++ b/csrc/scheduler/pointwise_heuristic.h
@@ -41,7 +41,10 @@ class PointwiseParams : public HeuristicParams {
   // vectorization factor
   int64_t vectorization_factor = 1;
 
-  // Unroll factor
+
+  // Unroll on top of vectorization
+  // In the 2D scheduler, unroll the outer dimension to reuse loaded data across
+  // rows, reducing loaded bytes by the unroll factor.
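The reuse this comment describes can be sketched standalone. A minimal scalar model, assuming `rows` divides evenly by the unroll factor, with `bias` standing in for any operand broadcast across rows:

    #include <cstdint>

    // Outer-dim unroll for out[i][j] = in[i][j] + bias[j]: one load of bias[j]
    // serves unroll_factor rows, cutting bias traffic by that factor.
    void outerUnrolledAdd(const float* in, const float* bias, float* out,
                          int64_t rows, int64_t cols, int64_t unroll_factor) {
      for (int64_t i = 0; i < rows; i += unroll_factor) {
        for (int64_t j = 0; j < cols; ++j) {
          float b = bias[j];  // loaded once per column position
          for (int64_t u = 0; u < unroll_factor; ++u) {
            out[(i + u) * cols + j] = in[(i + u) * cols + j] + b;
          }
        }
      }
    }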
int64_t unroll_factor = 1; using HeuristicParams::HeuristicParams; From 23efc80326a26492e9f4987a0c9b7a6efb7ce1aa Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 18 Oct 2024 19:10:16 -0700 Subject: [PATCH 04/17] enable unroll --- csrc/scheduler/pointwise.cpp | 112 ++++++++++++++++++++++++++++------- 1 file changed, 89 insertions(+), 23 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 55216d09409..a2354a362fd 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -60,6 +60,86 @@ class DomainMap : public pointwise_utils::DomainMap { } }; +// Class to handle expensive operations information and calculation of unroll +// factors +class ExpensiveOpInfo { + public: + ExpensiveOpInfo() : n_tanh_(0), n_exp_(0), n_reciprocal_(0) {} + + void analyzeFusion(Fusion* fusion) { + for (auto expr : fusion->exprs()) { + if (auto unary = dynamic_cast(expr)) { + switch (unary->getUnaryOpType()) { + case UnaryOpType::Tanh: + n_tanh_++; + break; + case UnaryOpType::Exp: + n_exp_++; + break; + case UnaryOpType::Reciprocal: + n_reciprocal_++; + break; + default: + break; + } + } + } + } + + std::string toString() const { + std::stringstream ss; + ss << "ExpensiveOpInfo: {"; + ss << "n_tanh: " << n_tanh_ << ", "; + ss << "n_exp: " << n_exp_ << ", "; + ss << "n_reciprocal: " << n_reciprocal_ << "}"; + return ss.str(); + } + + int64_t getComputationFactor() const { + return f_tanh_ * n_tanh_ + f_exp_ * n_exp_ + f_reciprocal_ * n_reciprocal_; + } + + private: + // Number of each expensive operation in the fusion + int n_tanh_; + int n_exp_; + int n_reciprocal_; + + // Empirical factors to consider the cost of each operation + static constexpr int f_tanh_ = 4; + static constexpr int f_exp_ = 1; + static constexpr int f_reciprocal_ = 1; +}; + +int64_t getTargetUnrollFactor(Fusion* fusion, std::vector io_tvs) { + // Multiple loading instructions are issued if there are multiple input tvs, + // so we should reduce unroll factor. + int64_t n_inputs = 0; + for (auto tv : io_tvs) { + if (tv->isFusionInput()) { + n_inputs++; + } + } + + // Analyze the fusion to determine the number of expensive operations + // When computation is expensive, increase unroll to have more overlap between + // computation and memory access. + ExpensiveOpInfo eops; + eops.analyzeFusion(fusion); + std::cout << "Expensive operations: " << eops.toString() << std::endl; + // Empirical model based on experiment of pointwise gelu, silu, and mul. + // (1) start with 2 + int64_t base_factor = 2; + // (2) increase unroll factor if computation is expensive + int64_t computation_factor = eops.getComputationFactor(); + // (3) decrease unroll factor if there are multiple input tensors + int64_t input_factor = scheduler_utils::lastPow2(n_inputs); + // (4) Results: gelu: 2 * 4 / 1 = 8, silu: 2 * 2 / 2 = 2, mul: 2 + int64_t unroll_factor = base_factor * computation_factor / input_factor; + unroll_factor = std::max(unroll_factor, 1L); + return base_factor * computation_factor / n_inputs; +} + } // namespace std::unique_ptr getPointwiseHeuristics( @@ -182,18 +262,8 @@ std::unique_ptr getPointwiseHeuristics( largest_out, true, true)); }); - constexpr int64_t kSixteen = 16; // clang tidy - - auto max_vect_unroll_factor = ceilDiv( - // Available unrolling based on size of data type - (int64_t)kSixteen / max_input_dtype_size, - // Reduce max unrolling factor if we have many inputs/outputs to unroll - // as it could start consuming a lot of registers. 
- std::max( - (scheduler_utils::lastPow2( - (int64_t)vectorizable_inputs_outputs_entry.get().size()) >> - 2), - (int64_t)1)); + auto max_vect_unroll_factor = + getTargetUnrollFactor(fusion, vectorizable_inputs_outputs_entry.get()); // Don't unroll at the cost of getting a full wave on the GPU if (n_elems < device_multiprocessor_count * kThreadX && @@ -203,6 +273,7 @@ std::unique_ptr getPointwiseHeuristics( ceilDiv(n_elems, device_multiprocessor_count * kThreadX)); } + constexpr int64_t kSixteen = 16; // clang tidy auto max_vect_factor = std::min(kSixteen / max_input_dtype_size, max_vect_unroll_factor); @@ -318,7 +389,7 @@ std::unique_ptr getPointwiseHeuristics( // Need to be able to parallelize, don't use break if there's not // at least an unrolled warp. - if (ceilDiv(cur_right_elem_count, max_vect_unroll_factor) <= + if (ceilDiv(cur_right_elem_count, max_vect_factor) <= at::cuda::getCurrentDeviceProperties()->warpSize) { continue; } @@ -334,8 +405,8 @@ std::unique_ptr getPointwiseHeuristics( flip_grid_binding = false; } // Min transfer found, start setting values - bdimx = std::min( - ceilDiv(cur_right_elem_count, max_vect_unroll_factor), kThreadX); + bdimx = + std::min(ceilDiv(cur_right_elem_count, max_vect_factor), kThreadX); bdimy = 1; // Put remainder in bdimy if there's at least a wave of grid level // parallelism. @@ -344,7 +415,7 @@ std::unique_ptr getPointwiseHeuristics( } auto remainder_left = ceilDiv(cur_left_elem_count, bdimy); auto remainder_right = - ceilDiv(cur_right_elem_count, bdimx * max_vect_unroll_factor); + ceilDiv(cur_right_elem_count, bdimx * max_vect_factor); // Use this break point break_point = static_cast(break_point_i); min_total_transfer = cur_transfer_size; @@ -365,13 +436,8 @@ std::unique_ptr getPointwiseHeuristics( break_point, logical_reorder_map)); - // preserve the old heuristic where unroll is used only when vectorization is - // not used. should allow to use both unroll and vectorization together in - // heuristics tuning. 
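The model this patch swaps in for the register-pressure cap can be restated on its own. A sketch of the arithmetic, using the weights from the ExpensiveOpInfo comments (function name is mine, and the clamps are assumptions to keep the result at least 1):

    #include <algorithm>
    #include <cstdint>

    // unroll = base(2) * computation / inputs, where computation weights tanh
    // at 4 and exp/reciprocal at 1 each. E.g. gelu (one tanh, one input):
    // 2 * 4 / 1 = 8; silu (one exp, one reciprocal, two inputs): 2 * 2 / 2 = 2.
    int64_t targetUnroll(int64_t n_tanh, int64_t n_exp, int64_t n_reciprocal,
                         int64_t n_inputs) {
      int64_t computation =
          std::max<int64_t>(4 * n_tanh + n_exp + n_reciprocal, 1);
      return std::max<int64_t>(
          2 * computation / std::max<int64_t>(n_inputs, 1), 1);
    }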
- if (params->vectorization_factor == 1) { - params->unroll_factor = scheduler_utils::safeDiv( - max_vect_unroll_factor, params->vectorization_factor); - } + params->unroll_factor = scheduler_utils::safeDiv( + max_vect_unroll_factor, params->vectorization_factor); NVF_ERROR(right_elem_count > 0 || break_point == 0); NVF_ERROR(!(bdimy > 1 && gdim_right > 1)); From 67450af8d4ab07ca8ac70dc70048a83ab9ff12a8 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 19 Oct 2024 16:59:15 -0700 Subject: [PATCH 05/17] adjust bdimx for divisible split --- csrc/scheduler/pointwise.cpp | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index a2354a362fd..c709d3faa3e 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -96,7 +96,9 @@ class ExpensiveOpInfo { } int64_t getComputationFactor() const { - return f_tanh_ * n_tanh_ + f_exp_ * n_exp_ + f_reciprocal_ * n_reciprocal_; + auto factor = + n_tanh_ * f_tanh_ + n_exp_ * f_exp_ + n_reciprocal_ * f_reciprocal_; + return std::max(factor, 1); } private: @@ -137,7 +139,10 @@ int64_t getTargetUnrollFactor(Fusion* fusion, std::vector io_tvs) { // (4) Results: gelu: 2 * 4 / 1 = 8, silu: 2 * 2 / 2 = 2, mul: 2 int64_t unroll_factor = base_factor * computation_factor / input_factor; unroll_factor = std::max(unroll_factor, 1L); - return base_factor * computation_factor / n_inputs; + std::cout << "n_inputs: " << n_inputs << std::endl; + std::cout << "computation_factor: " << computation_factor << std::endl; + std::cout << "unroll_factor: " << unroll_factor << std::endl; + return unroll_factor; } } // namespace @@ -173,8 +178,10 @@ std::unique_ptr getPointwiseHeuristics( NVF_ERROR(largest_out != nullptr); + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t device_multiprocessor_count = - (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + (int64_t)dev_prop->multiProcessorCount; // TODO: Set to 1? int64_t max_input_dtype_size = 2; @@ -262,8 +269,11 @@ std::unique_ptr getPointwiseHeuristics( largest_out, true, true)); }); - auto max_vect_unroll_factor = + constexpr int64_t kSixteen = 16; // clang tidy + auto max_vect_factor = kSixteen / max_input_dtype_size; + auto max_unroll_factor = getTargetUnrollFactor(fusion, vectorizable_inputs_outputs_entry.get()); + auto max_vect_unroll_factor = max_vect_factor * max_unroll_factor; // Don't unroll at the cost of getting a full wave on the GPU if (n_elems < device_multiprocessor_count * kThreadX && @@ -273,10 +283,6 @@ std::unique_ptr getPointwiseHeuristics( ceilDiv(n_elems, device_multiprocessor_count * kThreadX)); } - constexpr int64_t kSixteen = 16; // clang tidy - auto max_vect_factor = - std::min(kSixteen / max_input_dtype_size, max_vect_unroll_factor); - // See pointwise.h to understand what we're doing for this 2D analysis. // Ideal break point location int break_point = 0; @@ -436,6 +442,24 @@ std::unique_ptr getPointwiseHeuristics( break_point, logical_reorder_map)); + // For 2D scheduler, try to avoid undivisible split by adjust bdimx, e.g. when + // right elem count is 1280, vectorized by 8, up to now, 2 blocks each with + // 128 threads is used, divisible split is achived if use 1 block with 160 + // threads (160 x 8 = 1280). On H100, perf increased from 68% SOL to 91% SOL. + // Set a max bimx to leave at least 2 blocks per SM to switch between each + // other when one block is stalled. 
+ if (gdim_right > 1 && + right_elem_count % (bdimx * params->vectorization_factor) != 0) { + const int64_t max_bimdx = dev_prop->maxThreadsPerBlock / 2; + int64_t divisible_bimdx = right_elem_count / params->vectorization_factor; + divisible_bimdx = scheduler_utils::roundUpToN(divisible_bimdx, 32); + bdimx = divisible_bimdx > max_bimdx ? bdimx : divisible_bimdx; + gdim_right = + ceilDiv(right_elem_count / params->vectorization_factor, bdimx); + } + std::cout << "gdim_left: " << gdim_left << std::endl; + std::cout << "gdim_right: " << gdim_right << std::endl; + params->unroll_factor = scheduler_utils::safeDiv( max_vect_unroll_factor, params->vectorization_factor); From 35af0920e20548fe152c1e6feef5c7c803a0eff7 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Sat, 19 Oct 2024 19:37:06 -0700 Subject: [PATCH 06/17] test heurs --- csrc/scheduler/pointwise.cpp | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index c709d3faa3e..80d527550bd 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -395,7 +395,7 @@ std::unique_ptr getPointwiseHeuristics( // Need to be able to parallelize, don't use break if there's not // at least an unrolled warp. - if (ceilDiv(cur_right_elem_count, max_vect_factor) <= + if (ceilDiv(cur_right_elem_count, max_vect_factor) < at::cuda::getCurrentDeviceProperties()->warpSize) { continue; } @@ -414,11 +414,11 @@ std::unique_ptr getPointwiseHeuristics( bdimx = std::min(ceilDiv(cur_right_elem_count, max_vect_factor), kThreadX); bdimy = 1; - // Put remainder in bdimy if there's at least a wave of grid level - // parallelism. - if (cur_left_elem_count > device_multiprocessor_count) { - bdimy = kThreadX / bdimx; - } + // // Put remainder in bdimy if there's at least a wave of grid level + // // parallelism. + // if (cur_left_elem_count > device_multiprocessor_count) { + // bdimy = kThreadX / bdimx; + // } auto remainder_left = ceilDiv(cur_left_elem_count, bdimy); auto remainder_right = ceilDiv(cur_right_elem_count, bdimx * max_vect_factor); @@ -433,7 +433,7 @@ std::unique_ptr getPointwiseHeuristics( } } - params->vectorization_factor = std::min( + int64_t vectorization_factor = std::min( max_vect_factor, vectorize_helper::getVectorizationFactor( runtime_info, @@ -441,16 +441,21 @@ std::unique_ptr getPointwiseHeuristics( data_cache, break_point, logical_reorder_map)); - + // // when right size is small, decrease vectorization + // while(bdimx * 2 <= kThreadX && vectorization_factor / 2 >= 2 ) { + // vectorization_factor /= 2; + // bdimx *= 2; + // } + params->vectorization_factor = vectorization_factor; // For 2D scheduler, try to avoid undivisible split by adjust bdimx, e.g. when // right elem count is 1280, vectorized by 8, up to now, 2 blocks each with // 128 threads is used, divisible split is achived if use 1 block with 160 // threads (160 x 8 = 1280). On H100, perf increased from 68% SOL to 91% SOL. - // Set a max bimx to leave at least 2 blocks per SM to switch between each + // Set a max bimx to leave at least 4 blocks per SM to switch between each // other when one block is stalled. 
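Worked standalone, the adjustment reads as follows (helper name is mine; the rounding mirrors scheduler_utils::roundUpToN with a warp size of 32, and the cap reflects this revision's maxThreadsPerBlock / 4):

    #include <cstdint>

    // Prefer a bdimx that divides the vectorized inner extent exactly, rounded
    // up to a warp multiple, unless it exceeds maxThreadsPerBlock / 4; e.g.
    // 1280 elements vectorized by 8 -> 160 threads, a divisible split.
    int64_t pickDivisibleBdimx(int64_t right_elem_count, int64_t vect,
                               int64_t bdimx, int64_t max_threads_per_block) {
      const int64_t max_bdimx = max_threads_per_block / 4;
      int64_t divisible = right_elem_count / vect;
      divisible = (divisible + 31) / 32 * 32;  // round up to a warp multiple
      return divisible > max_bdimx ? bdimx : divisible;
    }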
if (gdim_right > 1 && right_elem_count % (bdimx * params->vectorization_factor) != 0) { - const int64_t max_bimdx = dev_prop->maxThreadsPerBlock / 2; + const int64_t max_bimdx = dev_prop->maxThreadsPerBlock / 4; int64_t divisible_bimdx = right_elem_count / params->vectorization_factor; divisible_bimdx = scheduler_utils::roundUpToN(divisible_bimdx, 32); bdimx = divisible_bimdx > max_bimdx ? bdimx : divisible_bimdx; From 089dd85f7234ed330f021c147720a0c6615a34c5 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Mon, 21 Oct 2024 17:53:59 -0700 Subject: [PATCH 07/17] unroll inner and outer --- benchmarks/cpp/utils.cpp | 7 +- csrc/python_frontend/python_bindings.cpp | 8 +- csrc/scheduler/pointwise.cpp | 154 ++++++++++++++--------- csrc/scheduler/pointwise_heuristic.h | 21 ++-- tests/cpp/test_pointwise.cpp | 6 +- 5 files changed, 118 insertions(+), 78 deletions(-) diff --git a/benchmarks/cpp/utils.cpp b/benchmarks/cpp/utils.cpp index c96a2e93456..017cf0968ee 100644 --- a/benchmarks/cpp/utils.cpp +++ b/benchmarks/cpp/utils.cpp @@ -97,9 +97,12 @@ std::string toString(const PointwiseParams* pparams) { if (pparams->vectorization_factor > 1) { ss << "Vectorize, Factor: " << pparams->vectorization_factor << "\n"; } - if (pparams->unroll_factor > 1) { - ss << "Unroll, Factor: " << pparams->unroll_factor << "\n"; + if (pparams->unroll_factor_outer > 1) { + ss << "Outer Unroll, Factor: " << pparams->unroll_factor_outer << "\n"; } + if (pparams->unroll_factor_inner > 1) { + ss << "Inner Unroll, Factor: " << pparams->unroll_factor_inner << "\n"; + } return ss.str(); } diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 931e18b6dad..33b1df6fbbc 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -684,10 +684,10 @@ void initNvFuserPythonBindings(PyObject* module) { self.vectorization_factor = vectorization_factor_; }); pointwise_config.def_property( - "unroll_factor", - [](PointwiseParams& self) { return self.unroll_factor; }, - [](PointwiseParams& self, int64_t unroll_factor_) { - self.unroll_factor = unroll_factor_; + "unroll_factor_outer", + [](PointwiseParams& self) { return self.unroll_factor_outer; }, + [](PointwiseParams& self, int64_t unroll_factor_outer_) { + self.unroll_factor_outer = unroll_factor_outer_; }); pointwise_config.def( "__repr__", [](const PointwiseParams& self) { return self.toString(); }); diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 80d527550bd..82d6194b107 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -98,7 +98,8 @@ class ExpensiveOpInfo { int64_t getComputationFactor() const { auto factor = n_tanh_ * f_tanh_ + n_exp_ * f_exp_ + n_reciprocal_ * f_reciprocal_; - return std::max(factor, 1); + factor = std::max(factor, 1); + return factor; } private: @@ -139,10 +140,7 @@ int64_t getTargetUnrollFactor(Fusion* fusion, std::vector io_tvs) { // (4) Results: gelu: 2 * 4 / 1 = 8, silu: 2 * 2 / 2 = 2, mul: 2 int64_t unroll_factor = base_factor * computation_factor / input_factor; unroll_factor = std::max(unroll_factor, 1L); - std::cout << "n_inputs: " << n_inputs << std::endl; - std::cout << "computation_factor: " << computation_factor << std::endl; - std::cout << "unroll_factor: " << unroll_factor << std::endl; - return unroll_factor; + return base_factor * computation_factor / n_inputs; } } // namespace @@ -155,7 +153,7 @@ std::unique_ptr getPointwiseHeuristics( // Incase any buffer is of type DataType::Index const auto 
index_type = runtime_info.getIndexType(); - + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); auto params = std::make_unique(); params->tag = "Pointwise heuristics"; params->cparams.index_type = index_type; @@ -178,10 +176,8 @@ std::unique_ptr getPointwiseHeuristics( NVF_ERROR(largest_out != nullptr); - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t device_multiprocessor_count = - (int64_t)dev_prop->multiProcessorCount; + (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; // TODO: Set to 1? int64_t max_input_dtype_size = 2; @@ -274,7 +270,6 @@ std::unique_ptr getPointwiseHeuristics( auto max_unroll_factor = getTargetUnrollFactor(fusion, vectorizable_inputs_outputs_entry.get()); auto max_vect_unroll_factor = max_vect_factor * max_unroll_factor; - // Don't unroll at the cost of getting a full wave on the GPU if (n_elems < device_multiprocessor_count * kThreadX && max_vect_unroll_factor > 1) { @@ -366,6 +361,11 @@ std::unique_ptr getPointwiseHeuristics( continue; } + std::cout << "break_point_i: " << break_point_i + << ", cur_left_elem_count: " << cur_left_elem_count + << ", cur_right_elem_count: " << cur_right_elem_count + << std::endl; + auto lhs_byte_multiple = broadcast_byte_multiples[break_point_i].lhs_multiple; auto rhs_byte_multiple = @@ -373,18 +373,24 @@ std::unique_ptr getPointwiseHeuristics( // Estimate transfer cost with this break point int64_t cur_transfer_size = 1; + int64_t left_transfer_size = 1; int64_t right_transfer_size = 1; for (const auto left_i : c10::irange(break_point_i)) { - cur_transfer_size = - cur_transfer_size * elem_counts[left_i] * lhs_byte_multiple; + left_transfer_size = + left_transfer_size * elem_counts[left_i] * lhs_byte_multiple; } for (const auto right_i : c10::irange(break_point_i, ref_root.size())) { right_transfer_size = right_transfer_size * elem_counts[right_i] * rhs_byte_multiple; } - cur_transfer_size *= right_transfer_size; + cur_transfer_size = left_transfer_size * right_transfer_size; + + std::cout << "left_transfer_size: " << left_transfer_size + << ", right_transfer_size: " << right_transfer_size + << ", cur_transfer_size: " << cur_transfer_size + << ", transfer_size_1d: " << transfer_size_1d << std::endl; // Continue if this break point doesn't save at least 10% of 1D // scheduling or isn't better than previous break_points found. @@ -395,11 +401,10 @@ std::unique_ptr getPointwiseHeuristics( // Need to be able to parallelize, don't use break if there's not // at least an unrolled warp. - if (ceilDiv(cur_right_elem_count, max_vect_factor) < + if (ceilDiv(cur_right_elem_count, max_vect_factor) <= at::cuda::getCurrentDeviceProperties()->warpSize) { continue; } - // If outer broadcast, or balanced broadcast: if (lhs_byte_multiple <= rhs_byte_multiple && // If right transfer size is bigger than half of L2 @@ -414,11 +419,11 @@ std::unique_ptr getPointwiseHeuristics( bdimx = std::min(ceilDiv(cur_right_elem_count, max_vect_factor), kThreadX); bdimy = 1; - // // Put remainder in bdimy if there's at least a wave of grid level - // // parallelism. - // if (cur_left_elem_count > device_multiprocessor_count) { - // bdimy = kThreadX / bdimx; - // } + // Put remainder in bdimy if there's at least a wave of grid level + // parallelism. 
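Restated standalone (names are mine), the block-shape choice this hunk restores is:

    #include <algorithm>
    #include <cstdint>

    // Fill bdimx from the vectorized inner extent up to kThreadX; spend the
    // leftover threads on bdimy only when the outer extent offers at least one
    // block per SM of grid-level parallelism.
    void pickBlockShape(int64_t right_elems, int64_t vect, int64_t left_elems,
                        int64_t kThreadX, int64_t sm_count, int64_t& bdimx,
                        int64_t& bdimy) {
      bdimx = std::min((right_elems + vect - 1) / vect, kThreadX);
      bdimy = (left_elems > sm_count) ? kThreadX / bdimx : 1;
    }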
+ if (cur_left_elem_count > device_multiprocessor_count) { + bdimy = kThreadX / bdimx; + } auto remainder_left = ceilDiv(cur_left_elem_count, bdimy); auto remainder_right = ceilDiv(cur_right_elem_count, bdimx * max_vect_factor); @@ -433,7 +438,7 @@ std::unique_ptr getPointwiseHeuristics( } } - int64_t vectorization_factor = std::min( + params->vectorization_factor = std::min( max_vect_factor, vectorize_helper::getVectorizationFactor( runtime_info, @@ -441,32 +446,41 @@ std::unique_ptr getPointwiseHeuristics( data_cache, break_point, logical_reorder_map)); - // // when right size is small, decrease vectorization - // while(bdimx * 2 <= kThreadX && vectorization_factor / 2 >= 2 ) { - // vectorization_factor /= 2; - // bdimx *= 2; - // } - params->vectorization_factor = vectorization_factor; - // For 2D scheduler, try to avoid undivisible split by adjust bdimx, e.g. when - // right elem count is 1280, vectorized by 8, up to now, 2 blocks each with - // 128 threads is used, divisible split is achived if use 1 block with 160 - // threads (160 x 8 = 1280). On H100, perf increased from 68% SOL to 91% SOL. - // Set a max bimx to leave at least 4 blocks per SM to switch between each - // other when one block is stalled. - if (gdim_right > 1 && - right_elem_count % (bdimx * params->vectorization_factor) != 0) { - const int64_t max_bimdx = dev_prop->maxThreadsPerBlock / 4; - int64_t divisible_bimdx = right_elem_count / params->vectorization_factor; - divisible_bimdx = scheduler_utils::roundUpToN(divisible_bimdx, 32); - bdimx = divisible_bimdx > max_bimdx ? bdimx : divisible_bimdx; - gdim_right = - ceilDiv(right_elem_count / params->vectorization_factor, bdimx); - } - std::cout << "gdim_left: " << gdim_left << std::endl; - std::cout << "gdim_right: " << gdim_right << std::endl; - params->unroll_factor = scheduler_utils::safeDiv( + // limit unroll factor when n_elems is small (e.g. less than 16K x 4K on H100) + // to use at least 8 waves to benefit from Thread-Level-Parallelism. Ideally, + // target wave depends on hardware, when computation latency is close to + // memory latency a smaller wave can be used. + const int64_t target_waves = 8L; + int64_t max_block_per_sm = + (int64_t)dev_prop->maxThreadsPerMultiProcessor / kThreadX; + int64_t total_blocks = break_point > 0 + ? gdim_left * gdim_right + : ceilDiv(n_elems / max_vect_factor, kThreadX); + int64_t n_waves_wo_unroll = + ceilDiv(total_blocks, max_block_per_sm * device_multiprocessor_count); + int64_t n_elems_limited_unroll = ceilDiv(n_waves_wo_unroll, target_waves); + int64_t resource_limited_unroll = scheduler_utils::safeDiv( max_vect_unroll_factor, params->vectorization_factor); + // don't unroll if unroll is input size limited and split is not divisible + if (n_elems_limited_unroll < resource_limited_unroll) { + bool divisible_split = break_point > 0 + ? (right_elem_count % (params->vectorization_factor * bdimx) == 0) + : (n_elems % (params->vectorization_factor * kThreadX) == 0); + n_elems_limited_unroll = divisible_split ? 
n_elems_limited_unroll : 1; + } + std::cout << "n_elems_limited_unroll: " << n_elems_limited_unroll + << ", resource limited unroll: " << resource_limited_unroll + << std::endl; + int64_t target_unroll_factor = + std::min(n_elems_limited_unroll, resource_limited_unroll); + + // int64_t =scheduler_utils::safeDiv( + // max_vect_unroll_factor, params->vectorization_factor); + + + params->unroll_factor_outer = 2; + params->unroll_factor_inner = 3; NVF_ERROR(right_elem_count > 0 || break_point == 0); NVF_ERROR(!(bdimy > 1 && gdim_right > 1)); @@ -489,7 +503,8 @@ std::unique_ptr getPointwiseHeuristics( << "num_elems: " << n_elems << "\n" << "elem_counts: " << elem_counts << "\n" << "max_input_dtype_size: " << max_input_dtype_size << "\n" - << "unroll_factor: " << params->unroll_factor << std::endl + << "unroll_factor_outer: " << params->unroll_factor_outer + << std::endl << "vectorize_factor: " << params->vectorization_factor << std::endl << "\n" << "logical_reorder_map: "; @@ -780,7 +795,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { reference_tv->reorder({{lhs_i, 0}, {-1, 1}}); // vectorization without unroll - if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) { + if (pparams->unroll_factor_outer == 1 && + pparams->unroll_factor_inner == 1 && + pparams->vectorization_factor > 1) { reference_tv->split(1, pparams->vectorization_factor); reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); reference_tv->split(0, 1); @@ -803,14 +820,24 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); // [outer | i-remainder, TIDx, Vect] - reference_tv->split(0, pparams->unroll_factor); - // [o-remainder, Unroll| i-remainder, TIDx, Vect] + std::cout << "reference_tv: " << reference_tv->toString() << std::endl; + if (pparams->unroll_factor_outer > 1) { + reference_tv->split(0, pparams->unroll_factor_outer); + std::cout << "reference_tv: " << reference_tv->toString() << std::endl; + } + if (pparams->unroll_factor_inner > 1) { + reference_tv->split(-3, pparams->unroll_factor_inner); + std::cout << "reference_tv: " << reference_tv->toString() << std::endl; + } + // [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect] reference_tv->split(0, 1); - // [o-remainder, Unswitch, Unroll | i-remainder, TIDx, Vect] + // [o-remainder, Unswitch, o-Unroll| i-remainder, i-Unroll, TIDx, Vect] - reference_tv->reorder({{3, 1}}); - // [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] + int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 3 : 2; + reference_tv->reorder({{i_remainder_pos, 1}}); + // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] + std::cout << "reference_tv: " << reference_tv->toString() << std::endl; reference_tv->axis(2)->parallelize(ParallelType::Unswitch); // Here we do not set axis(3)->parallelize(Unroll) because we do not want @@ -818,11 +845,13 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // propagation process into two steps: // step 1: inline at the unswitch position for cached inputs and outputs // step 2: inline at the inner most dim for the rest of the graph - reference_tv->axis(4)->parallelize(ParallelType::TIDx); + int tidx_pos = pparams->vectorization_factor > 1 ? 
-2 : -1; + reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx); if (pparams->vectorization_factor > 1) { - vectorize_id = reference_tv->axis(5); + vectorize_id = reference_tv->axis(-1); } - // [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] + std::cout << "reference_tv: " << reference_tv->toString() << std::endl; + // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] } // Move out of the way to furthest left point @@ -899,7 +928,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // unmerged...] reference_tv->reorder({{-1, 0}}); - if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) { + if (pparams->unroll_factor_inner == 1 && + pparams->vectorization_factor > 1) { // Vectorize reference_tv->split(0, pparams->vectorization_factor); // Unswitch @@ -911,7 +941,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { reference_tv->axis(1)->parallelize(ParallelType::TIDx); reference_tv->axis(2)->parallelize(ParallelType::Unswitch); // Vectorization are propagated separately - vectorize_id = reference_tv->axis(3); + vectorize_id = reference_tv->axis(-1); //[BIDx, TIDx, Unswitch, Vectorization] // To make consistent with unrolling: @@ -925,7 +955,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // Threads reference_tv->split(0, kThreadX); // Unroll - reference_tv->split(0, pparams->unroll_factor); + if (pparams->unroll_factor_inner > 1) { + reference_tv->split(0, pparams->unroll_factor_inner); + } // Unswitch reference_tv->split(0, 1); @@ -938,9 +970,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // step 1: inline at the unswitch position for cached inputs and outputs // step 2: inline at the inner most dim for the rest of the graph reference_tv->axis(3)->parallelize(ParallelType::TIDx); - if (pparams->vectorization_factor > 1) { - vectorize_id = reference_tv->axis(4); - } + vectorize_id = reference_tv->axis(-1); } unswitch_pos = 2; } @@ -1034,4 +1064,4 @@ void PointWiseScheduler::schedule( schedulePointwise(fusion, pparams); } -} // namespace nvfuser +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/scheduler/pointwise_heuristic.h b/csrc/scheduler/pointwise_heuristic.h index 95a1112dd92..e587c17c21a 100644 --- a/csrc/scheduler/pointwise_heuristic.h +++ b/csrc/scheduler/pointwise_heuristic.h @@ -45,7 +45,13 @@ class PointwiseParams : public HeuristicParams { // Unroll on top of vectorization // In the 2D scheduler, unroll the outer dimension to reuse loaded data across // rows, reducing loaded bytes by the unroll factor. - int64_t unroll_factor = 1; + // Always equals 1 for 1D scheduler. + int64_t unroll_factor_outer = 1; + + // In the 2D scheduler, unroll the inner dimension to reuse loaded data across + // cols, reducing loaded bytes by the unroll factor. + // Also used in 1D scheduler. 
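The split of the old single factor into these two can be pictured with a scalar sketch (even divisibility assumed): the outer factor buys reuse of row-broadcast operands across rows, while the inner factor covers several column positions per iteration along the contiguous dimension.

    #include <cstdint>

    void bothUnrolls(const float* in, const float* bias, float* out,
                     int64_t rows, int64_t cols, int64_t u_outer,
                     int64_t u_inner) {
      for (int64_t i = 0; i < rows; i += u_outer) {
        for (int64_t j = 0; j < cols; j += u_inner) {
          for (int64_t uo = 0; uo < u_outer; ++uo) {    // unroll_factor_outer
            for (int64_t ui = 0; ui < u_inner; ++ui) {  // unroll_factor_inner
              out[(i + uo) * cols + (j + ui)] =
                  in[(i + uo) * cols + (j + ui)] + bias[j + ui];
            }
          }
        }
      }
    }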
+ int64_t unroll_factor_inner = 1; using HeuristicParams::HeuristicParams; @@ -60,7 +66,8 @@ class PointwiseParams : public HeuristicParams { other->break_point == break_point && other->split_block == split_block && other->split_grid_y_dim == split_grid_y_dim && - other->unroll_factor == unroll_factor && + other->unroll_factor_outer == unroll_factor_outer && + other->unroll_factor_inner == unroll_factor_inner && other->flip_grid_binding == flip_grid_binding; return attr_equal; } @@ -81,12 +88,9 @@ class PointwiseParams : public HeuristicParams { ss << " Split y grid dim\n"; } } - if (vectorization_factor > 1) { ss << "Vectorize, Factor: " << vectorization_factor << "\n"; - } - if (unroll_factor > 1) { - ss << "Unroll, Factor: " << unroll_factor << "\n"; - } + ss << "unroll_factor_outer: " << unroll_factor_outer << "\n"; + ss << "unroll_factor_inner: " << unroll_factor_inner << "\n"; if (flip_grid_binding) { ss << "Flip BIDx/BIDy bindings\n"; } @@ -100,7 +104,8 @@ class PointwiseParams : public HeuristicParams { static_cast(break_point) << 4 ^ static_cast(split_block) << 5 ^ static_cast(split_grid_y_dim) << 6 ^ - static_cast(unroll_factor) << 9 ^ + static_cast(unroll_factor_outer) << 7 ^ + static_cast(unroll_factor_inner) << 9 ^ static_cast(flip_grid_binding) << 10; return attr_hash; } diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 283c7101750..9baa8d542f4 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -695,15 +695,17 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) { // modify heuristics to enforce unroll on top of vectorization pparams->vectorization_factor = 4; - pparams->unroll_factor = 2; + pparams->unroll_factor_outer = 2; + pparams->unroll_factor_inner = 2; // schedule, compile, run, validate scheduler_instance->schedule(fusion.get(), pparams); + fusion->printMath(); FusionExecutor fe; fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams); auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams); const auto& lparams = fe.lastLaunchParams(); - ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor); + ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor_outer); testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__); } } // namespace nvfuser From 00571b40cd2b13c711567f6aaa5bc490a8f22137 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Thu, 24 Oct 2024 19:26:55 -0700 Subject: [PATCH 08/17] wip --- tests/cpp/test_pointwise.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp index 9baa8d542f4..93ca278cc9a 100644 --- a/tests/cpp/test_pointwise.cpp +++ b/tests/cpp/test_pointwise.cpp @@ -706,6 +706,7 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) { auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams); const auto& lparams = fe.lastLaunchParams(); ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor_outer); + ASSERT_EQ(lparams.gdimx(), dim1 / pparams->vectorization_factor / lparams.bdimx() / pparams->unroll_factor_outer); testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__); } } // namespace nvfuser From 377b7fc202142107cb54bd9926a1d4b6955eeb87 Mon Sep 17 00:00:00 2001 From: Liqiang Lu Date: Fri, 25 Oct 2024 09:48:22 -0700 Subject: [PATCH 09/17] tests --- benchmarks/cpp/utils.cpp | 2 +- csrc/scheduler/pointwise.cpp | 217 +++++++-------------------- csrc/scheduler/pointwise_heuristic.h | 6 +- tests/cpp/test_pointwise.cpp | 41 +++-- 4 files changed, 92 insertions(+), 174 deletions(-) diff 
--git a/benchmarks/cpp/utils.cpp b/benchmarks/cpp/utils.cpp index 017cf0968ee..e171badd9ae 100644 --- a/benchmarks/cpp/utils.cpp +++ b/benchmarks/cpp/utils.cpp @@ -102,7 +102,7 @@ std::string toString(const PointwiseParams* pparams) { } if (pparams->unroll_factor_inner > 1) { ss << "Inner Unroll, Factor: " << pparams->unroll_factor_inner << "\n"; - } + } return ss.str(); } diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 4661d478b31..b4c215ab0a7 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -61,89 +61,6 @@ class DomainMap : public pointwise_utils::DomainMap { } }; -// Class to handle expensive operations information and calculation of unroll -// factors -class ExpensiveOpInfo { - public: - ExpensiveOpInfo() : n_tanh_(0), n_exp_(0), n_reciprocal_(0) {} - - void analyzeFusion(Fusion* fusion) { - for (auto expr : fusion->exprs()) { - if (auto unary = dynamic_cast(expr)) { - switch (unary->getUnaryOpType()) { - case UnaryOpType::Tanh: - n_tanh_++; - break; - case UnaryOpType::Exp: - n_exp_++; - break; - case UnaryOpType::Reciprocal: - n_reciprocal_++; - break; - default: - break; - } - } - } - } - - std::string toString() const { - std::stringstream ss; - ss << "ExpensiveOpInfo: {"; - ss << "n_tanh: " << n_tanh_ << ", "; - ss << "n_exp: " << n_exp_ << ", "; - ss << "n_reciprocal: " << n_reciprocal_ << "}"; - return ss.str(); - } - - int64_t getComputationFactor() const { - auto factor = - n_tanh_ * f_tanh_ + n_exp_ * f_exp_ + n_reciprocal_ * f_reciprocal_; - factor = std::max(factor, 1); - return factor; - } - - private: - // Number of each expensive operation in the fusion - int n_tanh_; - int n_exp_; - int n_reciprocal_; - - // Empirical factors to consider the cost of each operation - static constexpr int f_tanh_ = 4; - static constexpr int f_exp_ = 1; - static constexpr int f_reciprocal_ = 1; -}; - -int64_t getTargetUnrollFactor(Fusion* fusion, std::vector io_tvs) { - // Multiple loading instructions are issued if there are multiple input tvs, - // so we should reduce unroll factor. - int64_t n_inputs = 0; - for (auto tv : io_tvs) { - if (tv->isFusionInput()) { - n_inputs++; - } - } - - // Analyze the fusion to determine the number of expensive operations - // When computation is expensive, increase unroll to have more overlap between - // computation and memory access. - ExpensiveOpInfo eops; - eops.analyzeFusion(fusion); - std::cout << "Expensive operations: " << eops.toString() << std::endl; - // Empirical model based on experiment of pointwise gelu, silu, and mul. 
- // (1) start with 2 - int64_t base_factor = 2; - // (2) increase unroll factor if computation is expensive - int64_t computation_factor = eops.getComputationFactor(); - // (3) decrease unroll factor if there are multiple input tensors - int64_t input_factor = scheduler_utils::lastPow2(n_inputs); - // (4) Results: gelu: 2 * 4 / 1 = 8, silu: 2 * 2 / 2 = 2, mul: 2 - int64_t unroll_factor = base_factor * computation_factor / input_factor; - unroll_factor = std::max(unroll_factor, 1L); - return base_factor * computation_factor / n_inputs; -} - } // namespace std::unique_ptr getPointwiseHeuristics( @@ -154,7 +71,7 @@ std::unique_ptr getPointwiseHeuristics( // Incase any buffer is of type DataType::Index const auto index_type = runtime_info.getIndexType(); - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + auto params = std::make_unique(); params->tag = "Pointwise heuristics"; params->cparams.index_type = index_type; @@ -267,10 +184,18 @@ std::unique_ptr getPointwiseHeuristics( }); constexpr int64_t kSixteen = 16; // clang tidy - auto max_vect_factor = kSixteen / max_input_dtype_size; - auto max_unroll_factor = - getTargetUnrollFactor(fusion, vectorizable_inputs_outputs_entry.get()); - auto max_vect_unroll_factor = max_vect_factor * max_unroll_factor; + + auto max_vect_unroll_factor = ceilDiv( + // Available unrolling based on size of data type + (int64_t)kSixteen / max_input_dtype_size, + // Reduce max unrolling factor if we have many inputs/outputs to unroll + // as it could start consuming a lot of registers. + std::max( + (scheduler_utils::lastPow2( + (int64_t)vectorizable_inputs_outputs_entry.get().size()) >> + 2), + (int64_t)1)); + // Don't unroll at the cost of getting a full wave on the GPU if (n_elems < device_multiprocessor_count * kThreadX && max_vect_unroll_factor > 1) { @@ -279,6 +204,9 @@ std::unique_ptr getPointwiseHeuristics( ceilDiv(n_elems, device_multiprocessor_count * kThreadX)); } + auto max_vect_factor = + std::min(kSixteen / max_input_dtype_size, max_vect_unroll_factor); + // See pointwise.h to understand what we're doing for this 2D analysis. 
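For reference, the per-break-point traffic estimate applied in the loop below can be written standalone (simplified; the byte multiples are the per-element footprints on each side of the candidate break):

    #include <cstdint>
    #include <vector>

    // traffic ~= prod(left extents) * lhs_bytes * prod(right extents) * rhs_bytes;
    // a break point is adopted only if it beats 1D scheduling by at least 10%.
    int64_t estimateTransfer(const std::vector<int64_t>& extents,
                             size_t break_point, int64_t lhs_byte_multiple,
                             int64_t rhs_byte_multiple) {
      int64_t left = 1;
      for (size_t i = 0; i < break_point; ++i) {
        left *= extents[i] * lhs_byte_multiple;
      }
      int64_t right = 1;
      for (size_t i = break_point; i < extents.size(); ++i) {
        right *= extents[i] * rhs_byte_multiple;
      }
      return left * right;
    }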
// Ideal break point location int break_point = 0; @@ -362,11 +290,6 @@ std::unique_ptr getPointwiseHeuristics( continue; } - std::cout << "break_point_i: " << break_point_i - << ", cur_left_elem_count: " << cur_left_elem_count - << ", cur_right_elem_count: " << cur_right_elem_count - << std::endl; - auto lhs_byte_multiple = broadcast_byte_multiples[break_point_i].lhs_multiple; auto rhs_byte_multiple = @@ -374,24 +297,18 @@ std::unique_ptr getPointwiseHeuristics( // Estimate transfer cost with this break point int64_t cur_transfer_size = 1; - int64_t left_transfer_size = 1; int64_t right_transfer_size = 1; for (const auto left_i : c10::irange(break_point_i)) { - left_transfer_size = - left_transfer_size * elem_counts[left_i] * lhs_byte_multiple; + cur_transfer_size = + cur_transfer_size * elem_counts[left_i] * lhs_byte_multiple; } for (const auto right_i : c10::irange(break_point_i, ref_root.size())) { right_transfer_size = right_transfer_size * elem_counts[right_i] * rhs_byte_multiple; } - cur_transfer_size = left_transfer_size * right_transfer_size; - - std::cout << "left_transfer_size: " << left_transfer_size - << ", right_transfer_size: " << right_transfer_size - << ", cur_transfer_size: " << cur_transfer_size - << ", transfer_size_1d: " << transfer_size_1d << std::endl; + cur_transfer_size *= right_transfer_size; // Continue if this break point doesn't save at least 10% of 1D // scheduling or isn't better than previous break_points found. @@ -402,10 +319,11 @@ std::unique_ptr getPointwiseHeuristics( // Need to be able to parallelize, don't use break if there's not // at least an unrolled warp. - if (ceilDiv(cur_right_elem_count, max_vect_factor) <= + if (ceilDiv(cur_right_elem_count, max_vect_unroll_factor) <= at::cuda::getCurrentDeviceProperties()->warpSize) { continue; } + // If outer broadcast, or balanced broadcast: if (lhs_byte_multiple <= rhs_byte_multiple && // If right transfer size is bigger than half of L2 @@ -417,8 +335,8 @@ std::unique_ptr getPointwiseHeuristics( flip_grid_binding = false; } // Min transfer found, start setting values - bdimx = - std::min(ceilDiv(cur_right_elem_count, max_vect_factor), kThreadX); + bdimx = std::min( + ceilDiv(cur_right_elem_count, max_vect_unroll_factor), kThreadX); bdimy = 1; // Put remainder in bdimy if there's at least a wave of grid level // parallelism. @@ -427,7 +345,7 @@ std::unique_ptr getPointwiseHeuristics( } auto remainder_left = ceilDiv(cur_left_elem_count, bdimy); auto remainder_right = - ceilDiv(cur_right_elem_count, bdimx * max_vect_factor); + ceilDiv(cur_right_elem_count, bdimx * max_vect_unroll_factor); // Use this break point break_point = static_cast(break_point_i); min_total_transfer = cur_transfer_size; @@ -448,40 +366,15 @@ std::unique_ptr getPointwiseHeuristics( break_point, logical_reorder_map)); - // limit unroll factor when n_elems is small (e.g. less than 16K x 4K on H100) - // to use at least 8 waves to benefit from Thread-Level-Parallelism. Ideally, - // target wave depends on hardware, when computation latency is close to - // memory latency a smaller wave can be used. - const int64_t target_waves = 8L; - int64_t max_block_per_sm = - (int64_t)dev_prop->maxThreadsPerMultiProcessor / kThreadX; - int64_t total_blocks = break_point > 0 - ? 
gdim_left * gdim_right - : ceilDiv(n_elems / max_vect_factor, kThreadX); - int64_t n_waves_wo_unroll = - ceilDiv(total_blocks, max_block_per_sm * device_multiprocessor_count); - int64_t n_elems_limited_unroll = ceilDiv(n_waves_wo_unroll, target_waves); - int64_t resource_limited_unroll = scheduler_utils::safeDiv( - max_vect_unroll_factor, params->vectorization_factor); - // don't unroll if unroll is input size limited and split is not divisible - if (n_elems_limited_unroll < resource_limited_unroll) { - bool divisible_split = break_point > 0 - ? (right_elem_count % (params->vectorization_factor * bdimx) == 0) - : (n_elems % (params->vectorization_factor * kThreadX) == 0); - n_elems_limited_unroll = divisible_split ? n_elems_limited_unroll : 1; + // preserve the old heuristic where unroll is used only when vectorization is + // not used. should allow to use both unroll and vectorization together in + // heuristics tuning. + if (params->vectorization_factor == 1) { + auto total_unroll = scheduler_utils::safeDiv( + max_vect_unroll_factor, params->vectorization_factor); + params->unroll_factor_inner = total_unroll; + params->unroll_factor_outer = 1L; } - std::cout << "n_elems_limited_unroll: " << n_elems_limited_unroll - << ", resource limited unroll: " << resource_limited_unroll - << std::endl; - int64_t target_unroll_factor = - std::min(n_elems_limited_unroll, resource_limited_unroll); - - // int64_t =scheduler_utils::safeDiv( - // max_vect_unroll_factor, params->vectorization_factor); - - - params->unroll_factor_outer = 2; - params->unroll_factor_inner = 3; NVF_ERROR(right_elem_count > 0 || break_point == 0); NVF_ERROR(!(bdimy > 1 && gdim_right > 1)); @@ -504,6 +397,8 @@ std::unique_ptr getPointwiseHeuristics( << "num_elems: " << n_elems << "\n" << "elem_counts: " << elem_counts << "\n" << "max_input_dtype_size: " << max_input_dtype_size << "\n" + << "unroll_factor_inner: " << params->unroll_factor_inner + << std::endl << "unroll_factor_outer: " << params->unroll_factor_outer << std::endl << "vectorize_factor: " << params->vectorization_factor << std::endl @@ -832,24 +727,22 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); // [outer | i-remainder, TIDx, Vect] - std::cout << "reference_tv: " << reference_tv->toString() << std::endl; + if (pparams->unroll_factor_inner > 1) { + reference_tv->split(1, pparams->unroll_factor_inner); + } + // [outer| i-remainder, i-Unroll, TIDx, Vect] + if (pparams->unroll_factor_outer > 1) { reference_tv->split(0, pparams->unroll_factor_outer); - std::cout << "reference_tv: " << reference_tv->toString() << std::endl; - } - if (pparams->unroll_factor_inner > 1) { - reference_tv->split(-3, pparams->unroll_factor_inner); - std::cout << "reference_tv: " << reference_tv->toString() << std::endl; } // [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect] reference_tv->split(0, 1); - // [o-remainder, Unswitch, o-Unroll| i-remainder, i-Unroll, TIDx, Vect] + // [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect] int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 
3 : 2; reference_tv->reorder({{i_remainder_pos, 1}}); // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] - std::cout << "reference_tv: " << reference_tv->toString() << std::endl; reference_tv->axis(2)->parallelize(ParallelType::Unswitch); // Here we do not set axis(3)->parallelize(Unroll) because we do not want @@ -862,22 +755,23 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { if (pparams->vectorization_factor > 1) { vectorize_id = reference_tv->axis(-1); } - // [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect] + // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] } // Move out of the way to furthest left point reference_tv->reorder({{1, 0}}); - // [i-remainder, o-remainder, Unswitch, Unroll, TIDx, Vect] + // [i-remainder, o-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect] if (pparams->split_block) { reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy)); - // [i-remainder, o-remainder, TIDy, Unswitch, Unroll, TIDx, Vect] + // [i-remainder, o-remainder, TIDy, Unswitch, o-Unroll, i-Unroll, TIDx, + // Vect] if (pparams->flip_grid_binding) { - // [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, Vect] + // [BIDy | BIDx, TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect] reference_tv->axis(1)->parallelize(ParallelType::BIDx); reference_tv->axis(2)->parallelize(ParallelType::TIDy); if (pparams->split_grid_y_dim) { - // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx, - // Vect] + // [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, o-Unroll, + // i-Unroll, TIDx, Vect] reference_tv->split(0, 65535); reference_tv->axis(1)->parallelize(ParallelType::BIDy); unswitch_pos = 5; @@ -886,12 +780,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { unswitch_pos = 4; } } else { - // [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx, Vect] + // [BIDx | BIDy TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect] reference_tv->axis(0)->parallelize(ParallelType::BIDx); reference_tv->axis(2)->parallelize(ParallelType::TIDy); if (pparams->split_grid_y_dim) { - // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx, - // Vect] + // [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, o-Unroll, + // i-Unroll, TIDx, Vect] reference_tv->split(1, 65535); reference_tv->axis(2)->parallelize(ParallelType::BIDy); unswitch_pos = 5; @@ -952,7 +846,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { reference_tv->axis(1)->parallelize(ParallelType::TIDx); reference_tv->axis(2)->parallelize(ParallelType::Unswitch); // Vectorization are propagated separately - vectorize_id = reference_tv->axis(-1); + vectorize_id = reference_tv->axis(3); //[BIDx, TIDx, Unswitch, Vectorization] // To make consistent with unrolling: @@ -980,8 +874,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { // propagation process into two steps: // step 1: inline at the unswitch position for cached inputs and outputs // step 2: inline at the inner most dim for the rest of the graph - reference_tv->axis(3)->parallelize(ParallelType::TIDx); - vectorize_id = reference_tv->axis(-1); + int tidx_pos = pparams->vectorization_factor > 1 ? 
-2 : -1;
+    reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
+    if (pparams->vectorization_factor > 1) {
+      vectorize_id = reference_tv->axis(-1);
+    }
   }
   unswitch_pos = 2;
 }
@@ -1075,4 +972,4 @@ void PointWiseScheduler::schedule(
   schedulePointwise(fusion, pparams);
 }
 
-} // namespace nvfuser
\ No newline at end of file
+} // namespace nvfuser
diff --git a/csrc/scheduler/pointwise_heuristic.h b/csrc/scheduler/pointwise_heuristic.h
index e587c17c21a..f52a9aa099b 100644
--- a/csrc/scheduler/pointwise_heuristic.h
+++ b/csrc/scheduler/pointwise_heuristic.h
@@ -88,9 +88,9 @@ class PointwiseParams : public HeuristicParams {
         ss << "  Split y grid dim\n";
       }
     }
-    ss << "Vectorize, Factor: " << vectorization_factor << "\n";
-    ss << "unroll_factor_outer: " << unroll_factor_outer << "\n";
-    ss << "unroll_factor_inner: " << unroll_factor_inner << "\n";
+    ss << "Vectorize, Factor: " << vectorization_factor << "\n";
+    ss << "unroll_factor_outer: " << unroll_factor_outer << "\n";
+    ss << "unroll_factor_inner: " << unroll_factor_inner << "\n";
     if (flip_grid_binding) {
       ss << "Flip BIDx/BIDy bindings\n";
     }
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index 93ca278cc9a..a92d0a243d5 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -666,7 +666,9 @@ TEST_F(PointwiseTest, VectorizeWithExpandedBroadcast) {
   EXPECT_GT(getVecSizeForPointwise(fec), 1);
 }
 
-TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
+using VectUnrollFactors = std::tuple<int64_t, int64_t, int64_t>;
+using PointwiseParamsTest = NVFuserFixtureParamTest<VectUnrollFactors>;
+TEST_P(PointwiseParamsTest, UnrollOnTopOfVectorize) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
 
@@ -685,7 +687,7 @@
   auto t1 = at::randn({dim1}, options);
   std::vector<c10::IValue> runtime_inputs{t0, t1};
 
-  // generate heuristics
+  // Generate heuristics
   SchedulerRuntimeInfo runtime_info(fusion.get(), runtime_inputs);
   auto scheduler_instance =
       SchedulerEntry::makeSchedulerInstance(SchedulerType::PointWise);
@@ -693,20 +695,39 @@
       scheduler_instance->computeHeuristics(fusion.get(), runtime_info);
   auto pparams = heuristic_params->as<PointwiseParams>();
 
-  // modify heuristics to enforce unroll on top of vectorization
-  pparams->vectorization_factor = 4;
-  pparams->unroll_factor_outer = 2;
-  pparams->unroll_factor_inner = 2;
+  // Modify heuristics to enforce unroll on top of vectorization
 
-  // schedule, compile, run, validate
+  // Set unroll factors from test parameters
+  auto [vect_factor, unroll_inner, unroll_outer] = GetParam();
+  pparams->unroll_factor_inner = unroll_inner;
+  pparams->unroll_factor_outer = unroll_outer;
+  pparams->vectorization_factor = vect_factor;
+
+  // Schedule, compile, run, validate
   scheduler_instance->schedule(fusion.get(), pparams);
-  fusion->printMath();
+  fusion->print();
   FusionExecutor fe;
   fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams);
   auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams);
   const auto& lparams = fe.lastLaunchParams();
-  ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor_outer);
-  ASSERT_EQ(lparams.gdimx(), dim1 / pparams->vectorization_factor / lparams.bdimx() / pparams->unroll_factor_outer);
+  ASSERT_EQ(lparams.gdimy(), dim0 / unroll_outer);
+  ASSERT_EQ(
+      lparams.gdimx(), dim1 / vect_factor / lparams.bdimx() / unroll_inner);
   testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__);
 }
+INSTANTIATE_TEST_SUITE_P(
+    ,
+    PointwiseParamsTest,
+    ::testing::Combine(
+        testing::Values(1, 4), 
// vectorization factors
+        testing::Values(1, 2), // inner unroll factors
+        testing::Values(1, 2) // outer unroll factors
+        ),
+    [](const testing::TestParamInfo<VectUnrollFactors>& info) -> std::string {
+      std::stringstream ss;
+      ss << "vect_" << std::get<0>(info.param);
+      ss << "_inner_unroll_" << std::get<1>(info.param);
+      ss << "_outer_unroll_" << std::get<2>(info.param);
+      return sanitizeTestName(ss.str());
+    });
 } // namespace nvfuser

From 12ad2e6a5d0edaae8f5ba09dee0857ccbe8c6579 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Fri, 25 Oct 2024 10:45:08 -0700
Subject: [PATCH 10/17] clean

---
 csrc/python_frontend/python_bindings.cpp | 6 ++++++
 tests/cpp/test_pointwise.cpp             | 1 -
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp
index 38b4f9aa647..a392f3950c4 100644
--- a/csrc/python_frontend/python_bindings.cpp
+++ b/csrc/python_frontend/python_bindings.cpp
@@ -719,6 +719,12 @@ void initNvFuserPythonBindings(PyObject* module) {
       [](PointwiseParams& self, int64_t unroll_factor_outer_) {
         self.unroll_factor_outer = unroll_factor_outer_;
       });
+  pointwise_config.def_property(
+      "unroll_factor_inner",
+      [](PointwiseParams& self) { return self.unroll_factor_inner; },
+      [](PointwiseParams& self, int64_t unroll_factor_inner_) {
+        self.unroll_factor_inner = unroll_factor_inner_;
+      });
   pointwise_config.def(
       "__repr__", [](const PointwiseParams& self) { return self.toString(); });
diff --git a/tests/cpp/test_pointwise.cpp b/tests/cpp/test_pointwise.cpp
index a92d0a243d5..552cb18f3a8 100644
--- a/tests/cpp/test_pointwise.cpp
+++ b/tests/cpp/test_pointwise.cpp
@@ -705,7 +705,6 @@ TEST_P(PointwiseParamsTest, UnrollOnTopOfVectorize) {
 
   // Schedule, compile, run, validate
   scheduler_instance->schedule(fusion.get(), pparams);
-  fusion->print();
   FusionExecutor fe;
   fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams);
   auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams);

From 94680b6f144f19aeea0bb6e06ca264b8bd48e254 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Fri, 25 Oct 2024 10:47:03 -0700
Subject: [PATCH 11/17] python

---
 csrc/python_frontend/python_bindings.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp
index a392f3950c4..f16b5ba3fe1 100644
--- a/csrc/python_frontend/python_bindings.cpp
+++ b/csrc/python_frontend/python_bindings.cpp
@@ -724,7 +724,7 @@ void initNvFuserPythonBindings(PyObject* module) {
       [](PointwiseParams& self) { return self.unroll_factor_inner; },
       [](PointwiseParams& self, int64_t unroll_factor_inner_) {
         self.unroll_factor_inner = unroll_factor_inner_;
-      });
+      });
   pointwise_config.def(
       "__repr__", [](const PointwiseParams& self) { return self.toString(); });

From c5b0365e9821e452389182ce3a86886b450d3a74 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Fri, 25 Oct 2024 12:16:59 -0700
Subject: [PATCH 12/17] fix pos

---
 csrc/scheduler/pointwise.cpp | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index b4c215ab0a7..708d3cafe60 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -750,10 +750,17 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
       // propagation process into two steps:
       // step 1: inline at the unswitch position for cached inputs and outputs
      // step 2: inline at the inner most dim for the rest of the graph
-      int tidx_pos = 
pparams->vectorization_factor > 1 ? -2 : -1;
+      int tidx_pos = 3;
+      if (pparams->unroll_factor_inner > 1) {
+        tidx_pos++;
+      }
+      if (pparams->unroll_factor_outer > 1) {
+        tidx_pos++;
+      }
       reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
       if (pparams->vectorization_factor > 1) {
-        vectorize_id = reference_tv->axis(-1);
+        // can't use -1, there may be a deviceId dim
+        vectorize_id = reference_tv->axis(tidx_pos + 1);
       }
       // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
     }
@@ -874,10 +881,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
     // propagation process into two steps:
     // step 1: inline at the unswitch position for cached inputs and outputs
     // step 2: inline at the inner most dim for the rest of the graph
-    int tidx_pos = pparams->vectorization_factor > 1 ? -2 : -1;
+    int tidx_pos = pparams->unroll_factor_inner > 1 ? 3 : 2;
     reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
     if (pparams->vectorization_factor > 1) {
-      vectorize_id = reference_tv->axis(-1);
+      vectorize_id = reference_tv->axis(tidx_pos + 1);
     }
   }
   unswitch_pos = 2;

From 6c22a3b3b9d35cf5d385feea7432bdc77407c9b9 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Sun, 27 Oct 2024 09:04:35 -0700
Subject: [PATCH 13/17] clean

---
 csrc/scheduler/pointwise_heuristic.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/scheduler/pointwise_heuristic.h b/csrc/scheduler/pointwise_heuristic.h
index f52a9aa099b..10d9185870e 100644
--- a/csrc/scheduler/pointwise_heuristic.h
+++ b/csrc/scheduler/pointwise_heuristic.h
@@ -88,7 +88,7 @@ class PointwiseParams : public HeuristicParams {
         ss << "  Split y grid dim\n";
       }
     }
-    ss << "Vectorize, Factor: " << vectorization_factor << "\n";
+    ss << "vectorization_factor: " << vectorization_factor << "\n";
     ss << "unroll_factor_outer: " << unroll_factor_outer << "\n";
     ss << "unroll_factor_inner: " << unroll_factor_inner << "\n";
     if (flip_grid_binding) {

From b23cb41ff85dbb863662ec6c32192ed74097282a Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Sun, 27 Oct 2024 17:59:13 -0700
Subject: [PATCH 14/17] split even when outer unroll factor == 1, should drop
 this commit, test code diff

---
 csrc/scheduler/pointwise.cpp | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index 7c088afb2f7..dd83ce69277 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -712,15 +712,15 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   }
   // [outer| i-remainder, i-Unroll, TIDx, Vect]
 
-  if (pparams->unroll_factor_outer > 1) {
-    reference_tv->split(0, pparams->unroll_factor_outer);
-  }
+  // TODO: Only split when unroll factor is greater than 1
+  reference_tv->split(0, pparams->unroll_factor_outer);
 
   // [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect]
 
   reference_tv->split(0, 1);
   // [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]
 
-  int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 
3 : 2;
+  // TODO: depends on unroll_factor_outer > 1 or not
+  int i_remainder_pos = 3;
   reference_tv->reorder({{i_remainder_pos, 1}});
   // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
@@ -734,9 +734,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   if (pparams->unroll_factor_inner > 1) {
     tidx_pos++;
   }
-  if (pparams->unroll_factor_outer > 1) {
-    tidx_pos++;
-  }
+  // TODO: only when unroll_factor_inner > 1
+  tidx_pos++;
   reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
   if (pparams->vectorization_factor > 1) {
     // can't use -1, there may be a deviceId dim

From 7e0cbfffd3fdb8c00144d6b834539e22025bf4c6 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Tue, 29 Oct 2024 12:29:12 -0700
Subject: [PATCH 15/17] Revert "split even when outer unroll factor == 1,
 should drop this commit, test code diff"

This reverts commit b23cb41ff85dbb863662ec6c32192ed74097282a.
---
 csrc/scheduler/pointwise.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index dd83ce69277..7c088afb2f7 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -712,15 +712,15 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   }
   // [outer| i-remainder, i-Unroll, TIDx, Vect]
 
-  // TODO: Only split when unroll factor is greater than 1
-  reference_tv->split(0, pparams->unroll_factor_outer);
+  if (pparams->unroll_factor_outer > 1) {
+    reference_tv->split(0, pparams->unroll_factor_outer);
+  }
 
   // [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect]
 
   reference_tv->split(0, 1);
   // [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]
 
-  // TODO: depends on unroll_factor_outer > 1 or not
-  int i_remainder_pos = 3;
+  int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 3 : 2;
   reference_tv->reorder({{i_remainder_pos, 1}});
   // [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
@@ -734,8 +734,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   if (pparams->unroll_factor_inner > 1) {
     tidx_pos++;
   }
-  // TODO: only when unroll_factor_inner > 1
-  tidx_pos++;
+  if (pparams->unroll_factor_outer > 1) {
+    tidx_pos++;
+  }
   reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
   if (pparams->vectorization_factor > 1) {
     // can't use -1, there may be a deviceId dim

From cb60e11a984ad42ee7b6dd21a77f7a50f7986852 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Wed, 30 Oct 2024 11:44:48 -0700
Subject: [PATCH 16/17] set unroll factor based on 1d or 2d scheduler

---
 csrc/scheduler/pointwise.cpp | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index 7c088afb2f7..845cfaa61c1 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -371,8 +371,19 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
   if (params->vectorization_factor == 1) {
     auto total_unroll = scheduler_utils::safeDiv(
         max_vect_unroll_factor, params->vectorization_factor);
-    params->unroll_factor_inner = total_unroll;
-    params->unroll_factor_outer = 1L;
+    // for 1D scheduler, unroll the inner dimension
+    // since there is no outer dimension.
+    if (break_point == 0) {
+      params->unroll_factor_inner = total_unroll;
+      params->unroll_factor_outer = 1L;
+    } else {
+      // for 2D scheduler, unroll the outer dimension
+      // to prioritize reuse across different rows; will
+      // be revised in heuristics tuning, e.g. 
unrolling different
+      // dims based on the broadcast dimension.
+      params->unroll_factor_inner = 1L;
+      params->unroll_factor_outer = total_unroll;
+    }
   }
 
   NVF_ERROR(right_elem_count > 0 || break_point == 0);

From 8616bc9412a04baf6e4221a03d480a476dbeb5e5 Mon Sep 17 00:00:00 2001
From: Liqiang Lu
Date: Wed, 30 Oct 2024 11:50:48 -0700
Subject: [PATCH 17/17] add comment

---
 csrc/scheduler/pointwise.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp
index 845cfaa61c1..bc7a0fb32c6 100644
--- a/csrc/scheduler/pointwise.cpp
+++ b/csrc/scheduler/pointwise.cpp
@@ -700,6 +700,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
   reference_tv->split(0, 1);
   // [outer, Unswitch | i-remainder, TIDx, Vectorization]
+  // Here and in the following comments:
+  // prefix [i] represents the inner dimension
+  // prefix [o] represents the outer dimension
+  // [|] separates the outer and inner dimensions
   reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
   reference_tv->axis(3)->parallelize(ParallelType::TIDx);
   // Vectorization are propagated separately
@@ -726,7 +730,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   if (pparams->unroll_factor_outer > 1) {
     reference_tv->split(0, pparams->unroll_factor_outer);
   }
-  // [o-remainder, o-Unroll| i-remainder, i-Unroll, TIDx, Vect]
+  // [o-remainder, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]
 
   reference_tv->split(0, 1);
   // [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]
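
A note on the reuse argument behind PATCH 16: the standalone C++ sketch below is illustrative only, not nvFuser code or generated code; kRows, kCols, kUnrollOuter, and kVect are assumed values chosen so every split divides evenly (no remainder handling needed). It shows the loop shape that unrolling the outer dimension on top of a vectorized inner dimension produces, and the reuse it buys: the vectorized load of the broadcast operand b is hoisted once per column tile and then reused by every unrolled row.

#include <cstdio>

// Illustrative sizes and factors (assumptions, divisible by construction).
constexpr int kRows = 8, kCols = 16;
constexpr int kUnrollOuter = 2; // stands in for unroll_factor_outer
constexpr int kVect = 4;        // stands in for vectorization_factor

int main() {
  static float a[kRows][kCols], b[kCols], out[kRows][kCols];
  for (int j = 0; j < kCols; ++j) {
    b[j] = static_cast<float>(j);
    for (int i = 0; i < kRows; ++i) {
      a[i][j] = static_cast<float>(i * kCols + j);
    }
  }

  // Loop shape: [o-remainder, o-Unroll | i-remainder, Vect]
  for (int i = 0; i < kRows; i += kUnrollOuter) {   // o-remainder
    for (int j = 0; j < kCols; j += kVect) {        // i-remainder
      float bvec[kVect]; // broadcast operand, hoisted once per column tile
      for (int v = 0; v < kVect; ++v) {
        bvec[v] = b[j + v];
      }
      for (int u = 0; u < kUnrollOuter; ++u) {      // o-Unroll
        for (int v = 0; v < kVect; ++v) {           // Vect (contiguous)
          // bvec[v] is reused by every unrolled row u, which is the reuse
          // the 2D heuristic prioritizes.
          out[i + u][j + v] = a[i + u][j + v] + bvec[v];
        }
      }
    }
  }
  std::printf("out[%d][%d] = %.1f\n", kRows - 1, kCols - 1,
              out[kRows - 1][kCols - 1]);
  return 0;
}

Unrolling the inner dimension instead would extend the tile along the row and offer no comparable reuse of a row-broadcast operand; that is the tradeoff the comment about unrolling different dims based on the broadcast dimension alludes to.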