iter grouped multiple reductions (#2332)
When a fusion has multiple outer reductions, ensure all of them use grouped reduction.
liqiangxl authored Jun 8, 2024
1 parent 25903d2 commit f5c8c9c
Showing 2 changed files with 110 additions and 50 deletions.
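
For intuition (an editorial sketch, not part of the commit): grouping lets several outer reductions share one pass over the data and one set of partial-result combines, instead of running separate reduction kernels. The CUDA kernel below is a hypothetical illustration; the names, launch shape, and atomicAdd-based combine are assumptions, not nvfuser's generated code, which instead shares a single grid synchronization across the grouped reductions.

// Illustrative sketch only, not nvfuser codegen output. Two "outer" (dim-0)
// reductions are computed in one grouped pass: each thread reads a row
// element once and feeds BOTH accumulators.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void grouped_outer_sums(
    const float* a,
    const float* b,
    float* sum_a,   // [cols], must be zero-initialized
    float* sum_ab,  // [cols], must be zero-initialized
    int rows,
    int cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= cols) {
    return;
  }
  float s0 = 0.f, s1 = 0.f;
  // Stride down the rows; one load of a/b serves both reductions.
  for (int r = blockIdx.y * blockDim.y + threadIdx.y; r < rows;
       r += blockDim.y * gridDim.y) {
    float va = a[r * cols + col];
    float vb = b[r * cols + col];
    s0 += va;       // reduction 1: sum(a, {0})
    s1 += va + vb;  // reduction 2: sum(a + b, {0})
  }
  // One combine step per reduction; a grouped grid reduction would likewise
  // share a single grid synchronization for both values.
  atomicAdd(&sum_a[col], s0);
  atomicAdd(&sum_ab[col], s1);
}

int main() {
  const int rows = 512, cols = 1024;
  float *a, *b, *sum_a, *sum_ab;
  cudaMallocManaged(&a, rows * cols * sizeof(float));
  cudaMallocManaged(&b, rows * cols * sizeof(float));
  cudaMallocManaged(&sum_a, cols * sizeof(float));
  cudaMallocManaged(&sum_ab, cols * sizeof(float));
  for (int i = 0; i < rows * cols; ++i) {
    a[i] = 1.f;
    b[i] = 2.f;
  }
  cudaMemset(sum_a, 0, cols * sizeof(float));
  cudaMemset(sum_ab, 0, cols * sizeof(float));
  dim3 block(32, 16), grid((cols + 31) / 32, 8);
  grouped_outer_sums<<<grid, block>>>(a, b, sum_a, sum_ab, rows, cols);
  cudaDeviceSynchronize();
  printf("sum_a[0]=%f (expect 512), sum_ab[0]=%f (expect 1536)\n",
         sum_a[0], sum_ab[0]);
  return 0;
}
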
59 changes: 9 additions & 50 deletions csrc/scheduler/reduction_utils.cpp
@@ -323,44 +323,6 @@ std::vector<int64_t> addBackBroadcasts(
   return axes;
 }
 
-// Check if a reduction is effectively an allreduce.
-bool isGridAllreduce(TensorView* reduction_tv) {
-  // Only Local tensor is converted to allreduce
-  if (reduction_tv->getMemoryType() != MemoryType::Local) {
-    return false;
-  }
-
-  // Collect all reduction parallel types
-  ParallelTypeBitmap reduction_parallel_types;
-  std::for_each(
-      reduction_tv->getLeafDomain().begin(),
-      reduction_tv->getLeafDomain().end(),
-      [&](auto id) {
-        if (id->isReduction() &&
-            isParallelTypeBlockDim(id->getParallelType())) {
-          reduction_parallel_types.set(id->getParallelType());
-        }
-      });
-
-  // If any of the reduction parallel types is used to parallelize
-  // the broadcast, it will be converted to an allreduce reduction expr
-  for (auto bcast_expr :
-       ir_utils::filterByType<BroadcastOp>(reduction_tv->uses())) {
-    auto bcast_tv = bcast_expr->out()->as<TensorView>();
-    if (std::any_of(
-            bcast_tv->getLeafDomain().begin(),
-            bcast_tv->getLeafDomain().end(),
-            [&](auto bcast_id) {
-              auto pt = bcast_id->getParallelType();
-              return isParallelTypeBlockDim(pt) &&
-                  reduction_parallel_types.get(pt);
-            })) {
-      return true;
-    }
-  }
-  return false;
-}
-
 void multiReductionInliner(
     Fusion* fusion,
     TensorView* reduction_tv,
@@ -548,19 +510,16 @@ void propagateParallelization(
       }
     }
   }
 
-  std::vector<TensorView*> allreduce_tvs;
-  std::copy_if(
-      reduction_tvs.begin(),
-      reduction_tvs.end(),
-      std::back_inserter(allreduce_tvs),
-      [&](auto tv) {
-        return reduction_tv != tv &&
-            reduction_scheduler_utils::isGridAllreduce(tv);
-      });
-  if (!allreduce_tvs.empty()) {
+  // Propagate group to other reduction tvs
+  if (use_grouped_reduction && reduction_tvs.size() > 1) {
+    std::vector<TensorView*> other_reduction_tvs;
+    std::copy_if(
+        reduction_tvs.begin(),
+        reduction_tvs.end(),
+        std::back_inserter(other_reduction_tvs),
+        [&](auto tv) { return reduction_tv != tv; });
     scheduler_utils::parallelizeAllLike(
-        reduction_tv, -1, allreduce_tvs, {ParallelType::Group});
+        reduction_tv, -1, other_reduction_tvs, {ParallelType::Group});
   }
 }
101 changes: 101 additions & 0 deletions tests/cpp/test_gpu_outer_reduction.cpp
@@ -2458,4 +2458,105 @@ TEST_F(OuterReductionTest, OuterReductionMagicScheduler) {
}
}

TEST_F(OuterReductionTest, IterGroupedMultipleReductions) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
auto tv1 = makeContigTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = sum(tv0, {0});
auto tv3 = add(tv0, tv1);
auto tv4 = sum(tv3, {0});
fusion.addOutput(tv2);
fusion.addOutput(tv4);
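// Two outer (dim-0) reductions over the same iteration domain; per this
// change, both should end up as iteration-grouped grid reductions.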

int vect = 4, unroll = 2;
int bdimx = 32, bdimy = 16;
int gdimx = 8, gdimy = 8;
int serial = 2;
int iter_dim = vect * bdimx * gdimx;
int redu_dim = unroll * bdimy * gdimy * serial;
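// So iter_dim = 4 * 32 * 8 = 1024 and redu_dim = 2 * 16 * 8 * 2 = 512.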

// manually set how to schedule the fusion
auto rparams = std::make_shared<ReductionParams>();
// vectorize
rparams->vectorize_iter_dom = true;
rparams->unroll_factor_iter_dom = vect;
// use bdimx
rparams->multiple_reds_per_blk = true;
rparams->block_dim_iter_dom = ParallelType::TIDx;
// use gdimx
rparams->grid_dim_iter_dom = ParallelType::BIDx;
// use unroll
rparams->unroll_factor_inner_reduction = unroll;
// use bdimy
rparams->cross_block_inner_reduction = true;
rparams->block_dim_inner_reduction = ParallelType::TIDy;
// use gdimy
rparams->cross_grid_inner_reduction = true;
rparams->split_grid_dim_inner_reduction = true;
rparams->grid_dim_inner_reduction = ParallelType::BIDy;
// set launch params
auto lparams = LaunchParams(
gdimx,
gdimy,
LaunchParams::UNINITIALIZED_VAL,
bdimx,
bdimy,
LaunchParams::UNINITIALIZED_VAL);
rparams->lparams = lparams;
scheduleReduction(&fusion, *rparams);
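// With a cross-grid (BIDy) reduction and a vectorized iteration domain, the
// scheduler is expected to turn the vectorized iteration axis of every
// reduction tv into ParallelType::Group; the loop below verifies this for
// both reductions.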

// Ensure we have two iteration-grouped grid reductions
int num_iter_grouped_reductions = 0;
const auto& reduction_tvs =
scheduler_utils::getReductionTvs(fusion_ptr.get());
for (auto tv : reduction_tvs) {
bool has_grid_reduction = false;
bool has_grouped_domain = false;
for (auto id : tv->getLeafDomain()) {
if (id->isReduction() && id->getParallelType() == ParallelType::BIDy) {
has_grid_reduction = true;
}
if (id->getParallelType() == ParallelType::Group) {
has_grouped_domain = true;
}
}
if (has_grid_reduction) {
EXPECT_TRUE(has_grouped_domain)
<< "Expect Iteration domain grouped grid reduction, tv: "
<< tv->toString();
}
if (has_grid_reduction && has_grouped_domain) {
num_iter_grouped_reductions++;
}
}
EXPECT_TRUE(num_iter_grouped_reductions == 2)
<< "Expect 2 Iteration domain grouped grid reductions, got: "
<< num_iter_grouped_reductions;

FusionExecutor fe;
std::vector<int64_t> shape({redu_dim, iter_dim});
auto options = at::TensorOptions().device(at::kCUDA, 0);
auto t0 = at::randn(shape, options);
auto t1 = at::randn(shape, options);
std::vector<c10::IValue> aten_inputs({t0, t1});

fe.compileFusion(&fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

testValidate(
&fusion,
cg_outputs,
aten_inputs,
{t0.sum(0), (t0 + t1).sum(0)},
__LINE__,
__FILE__,
"",
lparams);
}

} // namespace nvfuser
