Skip to content

Commit

Permalink
add knobs control inner dim unroll and outer dim unroll in pointwise …
Browse files Browse the repository at this point in the history
…scheduler (#3275)

**What's in this PR?**
(1) Added two knobs to control unroll in inner dim and outer dim for
pointwise scheduler
(2) The original unroll knob, which applied to the outer dim, is removed.
(3) Extended test `UnrollOnTopOfVectorize` to test 8 different
combinations of `vectorization`, `inner unroll`, and `outer unroll`.
(4) Neither `inner unroll` nor `outer unroll` is used in the heuristics.
They are always `1` unless `vectorization == 1`, in that case, `inner
unroll` may be used after all SMs are used.
(5) If `inner or outer unroll factor == 1`, we won't split out an
additional domain with size of `1`.

**Why?**
These two knobs allow more performance optimizations, e.g. unroll in
different dims based on broadcast dims.
  • Loading branch information
liqiangxl authored Nov 2, 2024
1 parent e6b285b commit c02e7ee
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 42 deletions.
7 changes: 5 additions & 2 deletions benchmarks/cpp/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ std::string toString(const PointwiseParams* pparams) {
if (pparams->vectorization_factor > 1) {
ss << "Vectorize, Factor: " << pparams->vectorization_factor << "\n";
}
if (pparams->unroll_factor > 1) {
ss << "Unroll, Factor: " << pparams->unroll_factor << "\n";
if (pparams->unroll_factor_outer > 1) {
ss << "Outer Unroll, Factor: " << pparams->unroll_factor_outer << "\n";
}
if (pparams->unroll_factor_inner > 1) {
ss << "Inner Unroll, Factor: " << pparams->unroll_factor_inner << "\n";
}
return ss.str();
}
Expand Down
3 changes: 2 additions & 1 deletion csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,8 @@ void defineHeuristicParamBindings(py::module& nvfuser) {
.PARAM(PointwiseParams, split_grid_y_dim)
.PARAM(PointwiseParams, flip_grid_binding)
.PARAM(PointwiseParams, vectorization_factor)
.PARAM(PointwiseParams, unroll_factor);
.PARAM(PointwiseParams, unroll_factor_inner)
.PARAM(PointwiseParams, unroll_factor_outer);

// Matmul scheduler parameters
INITHEURISTICPARAMS(MatmulParams)
Expand Down
89 changes: 66 additions & 23 deletions csrc/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,21 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
// not used. should allow to use both unroll and vectorization together in
// heuristics tuning.
if (params->vectorization_factor == 1) {
params->unroll_factor = scheduler_utils::safeDiv(
auto total_unroll = scheduler_utils::safeDiv(
max_vect_unroll_factor, params->vectorization_factor);
// for 1D scheduler, unroll the inner dimension
// since there is no outer dimension.
if (break_point == 0) {
params->unroll_factor_inner = total_unroll;
params->unroll_factor_outer = 1L;
} else {
// for 2D scheduler, unroll the outer dimension
// to prioritize reuse across different rows, will
// be revised in heuristics tuning, e.g. unroll different
// dims based on the broadcast dimension.
params->unroll_factor_inner = 1L;
params->unroll_factor_outer = total_unroll;
}
}

NVF_ERROR(right_elem_count > 0 || break_point == 0);
Expand All @@ -394,7 +407,10 @@ std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
<< "num_elems: " << n_elems << "\n"
<< "elem_counts: " << elem_counts << "\n"
<< "max_input_dtype_size: " << max_input_dtype_size << "\n"
<< "unroll_factor: " << params->unroll_factor << std::endl
<< "unroll_factor_inner: " << params->unroll_factor_inner
<< std::endl
<< "unroll_factor_outer: " << params->unroll_factor_outer
<< std::endl
<< "vectorize_factor: " << params->vectorization_factor << std::endl
<< "\n"
<< "logical_reorder_map: ";
Expand Down Expand Up @@ -677,11 +693,17 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
reference_tv->reorder({{lhs_i, 0}, {-1, 1}});

// vectorization without unroll
if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) {
if (pparams->unroll_factor_outer == 1 &&
pparams->unroll_factor_inner == 1 &&
pparams->vectorization_factor > 1) {
reference_tv->split(1, pparams->vectorization_factor);
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
reference_tv->split(0, 1);
// [outer, Unswitch | i-remainder, TIDx, Vectorization]
// Here and in the following comments:
// prefix [i] represents inner dimension
// prefix [o] represents outer dimension
// [|] separates the outer and inner dimensions
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
// Vectorization are propagated separately
Expand All @@ -700,41 +722,58 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
// [outer | i-remainder, TIDx, Vect]

reference_tv->split(0, pparams->unroll_factor);
// [o-remainder, Unroll| i-remainder, TIDx, Vect]
if (pparams->unroll_factor_inner > 1) {
reference_tv->split(1, pparams->unroll_factor_inner);
}
// [outer| i-remainder, i-Unroll, TIDx, Vect]

if (pparams->unroll_factor_outer > 1) {
reference_tv->split(0, pparams->unroll_factor_outer);
}
// [o-remainder, o-Unroll, | i-remainder, i-Unroll, TIDx, Vect]

reference_tv->split(0, 1);
// [o-remainder, Unswitch, Unroll | i-remainder, TIDx, Vect]
// [o-remainder, Unswitch, o-Unroll | i-remainder, i-Unroll, TIDx, Vect]

reference_tv->reorder({{3, 1}});
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect]
int i_remainder_pos = pparams->unroll_factor_outer > 1 ? 3 : 2;
reference_tv->reorder({{i_remainder_pos, 1}});
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]

reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Here we do not set axis(3)->parallelize(Unroll) because we do not want
// it to be propagated. We manually unroll by splitting the inline
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(4)->parallelize(ParallelType::TIDx);
int tidx_pos = 3;
if (pparams->unroll_factor_inner > 1) {
tidx_pos++;
}
if (pparams->unroll_factor_outer > 1) {
tidx_pos++;
}
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
if (pparams->vectorization_factor > 1) {
vectorize_id = reference_tv->axis(5);
// can't use {-1}, there may be deviceId
vectorize_id = reference_tv->axis(tidx_pos + 1);
}
// [o-remainder, i-remainder, Unswitch, Unroll, TIDx, Vect]
// [o-remainder, i-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
}

// Move out of the way to furthest left point
reference_tv->reorder({{1, 0}});
// [i-remainder, o-remainder, Unswitch, Unroll, TIDx, Vect]
// [i-remainder, o-remainder, Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
if (pparams->split_block) {
reference_tv->split(1, NamedScalar::getParallelDim(ParallelType::TIDy));
// [i-remainder, o-remainder, TIDy, Unswitch, Unroll, TIDx, Vect]
// [i-remainder, o-remainder, TIDy, Unswitch, o-Unroll, i-Unroll, TIDx,
// Vect]
if (pparams->flip_grid_binding) {
// [BIDy | BIDx, TIDy | Unswitch, Unroll, TIDx, Vect]
// [BIDy | BIDx, TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
reference_tv->axis(1)->parallelize(ParallelType::BIDx);
reference_tv->axis(2)->parallelize(ParallelType::TIDy);
if (pparams->split_grid_y_dim) {
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, Unroll, TIDx,
// Vect]
// [i-remainder, BIDy{65535} | BIDx, TIDy | Unswitch, o-Unroll,
// i-Unroll, TIDx, Vect]
reference_tv->split(0, 65535);
reference_tv->axis(1)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
Expand All @@ -743,12 +782,12 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
unswitch_pos = 4;
}
} else {
// [BIDx | BIDy TIDy | Unswitch, Unroll, TIDx, Vect]
// [BIDx | BIDy TIDy | Unswitch, o-Unroll, i-Unroll, TIDx, Vect]
reference_tv->axis(0)->parallelize(ParallelType::BIDx);
reference_tv->axis(2)->parallelize(ParallelType::TIDy);
if (pparams->split_grid_y_dim) {
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, Unroll, TIDx,
// Vect]
// [BIDx | i-remainder, BIDy{65535}, TIDy | Unswitch, o-Unroll,
// i-Unroll, TIDx, Vect]
reference_tv->split(1, 65535);
reference_tv->axis(2)->parallelize(ParallelType::BIDy);
unswitch_pos = 5;
Expand Down Expand Up @@ -796,7 +835,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// unmerged...]
reference_tv->reorder({{-1, 0}});

if (pparams->unroll_factor == 1 && pparams->vectorization_factor > 1) {
if (pparams->unroll_factor_inner == 1 &&
pparams->vectorization_factor > 1) {
// Vectorize
reference_tv->split(0, pparams->vectorization_factor);
// Unswitch
Expand All @@ -822,7 +862,9 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// Threads
reference_tv->split(0, kThreadX);
// Unroll
reference_tv->split(0, pparams->unroll_factor);
if (pparams->unroll_factor_inner > 1) {
reference_tv->split(0, pparams->unroll_factor_inner);
}
// Unswitch
reference_tv->split(0, 1);

Expand All @@ -834,9 +876,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
// propagation process into two steps:
// step 1: inline at the unswitch position for cached inputs and outputs
// step 2: inline at the inner most dim for the rest of the graph
reference_tv->axis(3)->parallelize(ParallelType::TIDx);
int tidx_pos = pparams->unroll_factor_inner > 1 ? 3 : 2;
reference_tv->axis(tidx_pos)->parallelize(ParallelType::TIDx);
if (pparams->vectorization_factor > 1) {
vectorize_id = reference_tv->axis(4);
vectorize_id = reference_tv->axis(tidx_pos + 1);
}
}
unswitch_pos = 2;
Expand Down
23 changes: 14 additions & 9 deletions csrc/scheduler/pointwise_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@ class PointwiseParams : public HeuristicParams {
// Unroll on top of vectorization
// In the 2D scheduler, unroll the outer dimension to reuse loaded data across
// rows, reducing loaded bytes by the unroll factor.
int64_t unroll_factor = 1;
// Always equals 1 for 1D scheduler.
int64_t unroll_factor_outer = 1;

// In the 2D scheduler, unroll the inner dimension to reuse loaded data across
// cols, reducing loaded bytes by the unroll factor.
// Also used in 1D scheduler.
int64_t unroll_factor_inner = 1;

using HeuristicParams::HeuristicParams;

Expand All @@ -60,7 +66,8 @@ class PointwiseParams : public HeuristicParams {
other->break_point == break_point &&
other->split_block == split_block &&
other->split_grid_y_dim == split_grid_y_dim &&
other->unroll_factor == unroll_factor &&
other->unroll_factor_outer == unroll_factor_outer &&
other->unroll_factor_inner == unroll_factor_inner &&
other->flip_grid_binding == flip_grid_binding;
return attr_equal;
}
Expand All @@ -81,12 +88,9 @@ class PointwiseParams : public HeuristicParams {
ss << " Split y grid dim\n";
}
}
if (vectorization_factor > 1) {
ss << "Vectorize, Factor: " << vectorization_factor << "\n";
}
if (unroll_factor > 1) {
ss << "Unroll, Factor: " << unroll_factor << "\n";
}
ss << "vectorization_factor: " << vectorization_factor << "\n";
ss << "unroll_factor_outer: " << unroll_factor_outer << "\n";
ss << "unroll_factor_inner: " << unroll_factor_inner << "\n";
if (flip_grid_binding) {
ss << "Flip BIDx/BIDy bindings\n";
}
Expand All @@ -100,7 +104,8 @@ class PointwiseParams : public HeuristicParams {
static_cast<size_t>(break_point) << 4 ^
static_cast<size_t>(split_block) << 5 ^
static_cast<size_t>(split_grid_y_dim) << 6 ^
static_cast<size_t>(unroll_factor) << 9 ^
static_cast<size_t>(unroll_factor_outer) << 7 ^
static_cast<size_t>(unroll_factor_inner) << 9 ^
static_cast<size_t>(flip_grid_binding) << 10;
return attr_hash;
}
Expand Down
37 changes: 30 additions & 7 deletions tests/cpp/test_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,9 @@ TEST_F(PointwiseTest, VectorizeWithExpandedBroadcast) {
EXPECT_GT(getVecSizeForPointwise(fec), 1);
}

TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
using VectUnrollFactors = std::tuple<int64_t, int64_t, int64_t>;
using PointwiseParamsTest = NVFuserFixtureParamTest<VectUnrollFactors>;
TEST_P(PointwiseParamsTest, UnrollOnTopOfVectorize) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

Expand All @@ -685,25 +687,46 @@ TEST_F(PointwiseTest, UnrollOnTopOfVectorize) {
auto t1 = at::randn({dim1}, options);
std::vector<c10::IValue> runtime_inputs{t0, t1};

// generate heuristics
// Generate heuristics
SchedulerRuntimeInfo runtime_info(fusion.get(), runtime_inputs);
auto scheduler_instance =
SchedulerEntry::makeSchedulerInstance(SchedulerType::PointWise);
auto heuristic_params =
scheduler_instance->computeHeuristics(fusion.get(), runtime_info);
auto pparams = heuristic_params->as<PointwiseParams>();

// modify heuristics to enforce unroll on top of vectorization
pparams->vectorization_factor = 4;
pparams->unroll_factor = 2;
// Modify heuristics to enforce unroll on top of vectorization

// schedule, compile, run, validate
// Set unroll factors from test parameters
auto [vect_factor, unroll_inner, unroll_outer] = GetParam();
pparams->unroll_factor_inner = unroll_inner;
pparams->unroll_factor_outer = unroll_outer;
pparams->vectorization_factor = vect_factor;

// Schedule, compile, run, validate
scheduler_instance->schedule(fusion.get(), pparams);
FusionExecutor fe;
fe.compileFusion(fusion.get(), runtime_inputs, pparams->lparams);
auto cg_outputs = fe.runFusion(runtime_inputs, pparams->lparams);
const auto& lparams = fe.lastLaunchParams();
ASSERT_EQ(lparams.gdimy(), dim0 / pparams->unroll_factor);
ASSERT_EQ(lparams.gdimy(), dim0 / unroll_outer);
ASSERT_EQ(
lparams.gdimx(), dim1 / vect_factor / lparams.bdimx() / unroll_inner);
testValidate(fusion.get(), cg_outputs, runtime_inputs, __LINE__, __FILE__);
}
// Instantiates the PointwiseParamsTest suite over the Cartesian product of
// vectorization factors {1, 4}, inner unroll factors {1, 2} and outer unroll
// factors {1, 2}, i.e. 8 parameter combinations. The instantiation prefix is
// intentionally left empty. Each case is named
// "vect_<v>_inner_unroll_<i>_outer_unroll_<o>" and then passed through
// sanitizeTestName (presumably to make it a valid gtest name -- confirm
// against its definition).
// NOTE: the name generator is a macro argument, so its body must not contain
// commas outside of parentheses.
INSTANTIATE_TEST_SUITE_P(
    ,
    PointwiseParamsTest,
    ::testing::Combine(
        testing::Values(1, 4), // vectorization factors
        testing::Values(1, 2), // inner unroll factors
        testing::Values(1, 2) // outer unroll factors
        ),
    [](const testing::TestParamInfo<VectUnrollFactors>& info) -> std::string {
      // Tuple layout: <vectorization, inner unroll, outer unroll>.
      const auto& factors = info.param;
      std::stringstream name;
      name << "vect_" << std::get<0>(factors) << "_inner_unroll_"
           << std::get<1>(factors) << "_outer_unroll_"
           << std::get<2>(factors);
      return sanitizeTestName(name.str());
    });
} // namespace nvfuser

0 comments on commit c02e7ee

Please sign in to comment.