Merge remote-tracking branch 'origin/main' into resize_scheduler_initial_version
naoyam committed Dec 11, 2024
2 parents 55b8499 + cd99c7d commit 791c85b
Showing 10 changed files with 921 additions and 346 deletions.
33 changes: 21 additions & 12 deletions csrc/multidevice/communication.cpp
@@ -429,27 +429,36 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
scattered_axis >= 0,
"scattered_axis is expected to be non-negative: ",
scattered_axis);
// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.

std::vector<at::Tensor> input_tensors = at::tensor_split(
input_tensor, communication->team_size(), scattered_axis);
// We could have checked the output shape as well if reduction_axis is
// available. It's not always available via
// `communication->out()->getReductionAxis()` for manually constructed host
// IRs like
// https://github.com/NVIDIA/Fuser/blob/89c47f695b296eb4ffd27984bd4c953fc3f3264b/tests/cpp/test_multidevice_overlap.cpp#L347.
assertBuffersHaveSameSize(input_tensors, {});

// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
if (scattered_axis == 0 &&
backend->getBackendName() == c10d::NCCL_BACKEND_NAME) {
return backend->_reduce_scatter_base(
output_tensor, input_tensor, {.reduceOp = communication->reduceOp()});
}
#endif
std::vector<std::vector<at::Tensor>> input_tensors(1);
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
std::vector<std::vector<at::Tensor>> input_tensors_vec({input_tensors});
std::vector<at::Tensor> output_tensor_vec({output_tensor});
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
output_tensor_vec,
input_tensors_vec,
{.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postSendRecv(
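
On the non-NCCL fallback path above, the old code split the input into size-1 slices with at::split, whereas the new code asks at::tensor_split for exactly team_size sections and then asserts that all resulting buffers have the same size. A minimal standalone sketch of the two ATen calls, assuming a hypothetical 6x4 input and a team of 3 ranks (the tensor, sizes, and function name are illustrative, not taken from this commit):

#include <ATen/ATen.h>

#include <vector>

void splitSketch() {
  at::Tensor input = at::arange(24, at::kFloat).reshape({6, 4});

  // Old path: one size-1 slice per index along the scattered axis,
  // i.e. 6 tensors of shape [1, 4].
  std::vector<at::Tensor> per_index =
      at::split(input, /*split_size=*/1, /*dim=*/0);

  // New path: exactly `team_size` sections along the scattered axis,
  // i.e. 3 tensors of shape [2, 4]. tensor_split also accepts extents that
  // do not divide evenly, which is why the new code still checks that every
  // rank's buffer ends up with the same size before posting the collective.
  std::vector<at::Tensor> per_rank =
      at::tensor_split(input, /*sections=*/3, /*dim=*/0);
}
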
77 changes: 34 additions & 43 deletions csrc/multidevice/utils.cpp
@@ -42,45 +42,39 @@ std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
return sharded_ids;
}

// Returns whether an IterDomain in a TensorView is the outermost
// allocated IterDomain in the TensorView.
bool isOutermostAllocatedId(TensorView* tv, IterDomain* id) {
for (auto i : tv->getLoopDomain()) {
if (i == id) {
return true;
// Returns the position where an axis is allocated in a tv, skipping trivial
// dimensions (i.e. DID, reduction and broadcast). Returns -1 if id is not in
// tv's loop domain. WAR: today we assume that the loop domain matches the
// actual allocation, but this will have to change in the future.
int64_t allocationIndex(TensorView* tv, IterDomain* id) {
int64_t index = 0;
for (auto* loop_id : tv->getLoopDomain()) {
if (loop_id == id) {
return index;
}
if (!i->isDeviceDim() && !i->isReduction() && !i->isBroadcast()) {
return false;
if (!loop_id->isDeviceDim() && !loop_id->isReduction() &&
!loop_id->isBroadcast()) {
index++;
}
}
NVF_THROW("Id", id->toString(), " is not in TensorView ", tv->toString());
return false;
return -1;
}

} // namespace

std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
Expr* expr) {
NVF_ERROR(
ir_utils::isTvOp(expr), "Expression must be a TvOp ", expr->toString());
NVF_ERROR(
expr->outputs().size() == 1,
"Resharding expression can only have one output");
NVF_ERROR(
expr->inputs().size() == 1,
"Resharding expression can have only one input");
auto output = expr->outputs().at(0)->as<TensorView>();
auto input = expr->inputs().at(0)->as<TensorView>();

TensorView* producer,
TensorView* consumer) {
std::vector<IterDomain*> shard_additions;
std::vector<IterDomain*> shard_deletions;
auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
auto rootmap =
PairwiseLogicalDomainMap(producer, consumer).mapBroadcast(false);
const auto c2p_map = rootmap.mapConsumerToProducer();

for (IterDomain* out_root : output->getMaybeRootDomain()) {
for (IterDomain* out_root : consumer->getMaybeRootDomain()) {
IterDomain* in_root = c2p_map.at(out_root);
// Ignore sharded broadcast domains and
// sharded reductions on the output
// sharded reductions on the consumer
// ex. DIDx(i0) -> r(i0) or DIDx(i0) -> r(DIDx(i0))
// since they don't affect allocation.
if (in_root->isDeviceDim() && !in_root->isBroadcast() &&
@@ -93,8 +87,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
} else if (in_root->isDeviceDim() && out_root->isDeviceDim()) {
NVF_ERROR(
in_root->getParallelType() == out_root->getParallelType(),
expr->toString(),
" reshards ",
" resharding ",
in_root->toString(),
" to ",
out_root->toString(),
@@ -462,23 +455,21 @@ bool isInnerResharding(Expr* expr) {
ir_utils::isTvOp(expr),
"Non-tv op is not supported : ",
expr->toString());
NVF_ERROR(
expr->outputs().size() == 1,
"Resharding operations can only have one output");
NVF_ERROR(
expr->inputs().size() == 1,
"Resharding operations can have only one input");
auto output = expr->outputs().at(0)->as<TensorView>();
auto input = expr->inputs().at(0)->as<TensorView>();
auto [shard_additions, shard_deletions] = getShardingChanges(expr);
NVF_ERROR(
shard_additions.size() + shard_deletions.size() <= 1,
"Resharding expr can only support one axis")

if (!shard_deletions.empty()) {
return !isOutermostAllocatedId(input, shard_deletions[0]);
} else if (!shard_additions.empty()) {
return !isOutermostAllocatedId(output, shard_additions[0]);
for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
auto [shard_additions, shard_deletions] =
getShardingChanges(input, output);
NVF_ERROR(
shard_additions.size() + shard_deletions.size() <= 1,
"Resharding expr can only support one axis")
if ((!shard_deletions.empty() &&
allocationIndex(input, shard_deletions.at(0)) > 0) ||
(!shard_additions.empty() &&
allocationIndex(output, shard_additions.at(0)) > 0)) {
return true;
}
}
}
return false;
}
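
The new allocationIndex() helper above reports where an axis sits among the non-trivial loop dimensions (device, reduction, and broadcast axes are skipped), and isInnerResharding() now flags an expression whenever the added or deleted sharded axis lands at an index greater than zero. A simplified standalone analogue, with a hypothetical AxisKind enum standing in for the IterDomain queries used by the real helper:

#include <cstdint>

#include <vector>

enum class AxisKind { Iter, Device, Reduction, Broadcast };

// Position of the axis at `target` among the allocated (non-trivial) axes;
// returns -1 if `target` falls outside the loop domain, mirroring the real
// helper's behavior when the id is not found.
int64_t allocationIndexSketch(
    const std::vector<AxisKind>& loop,
    size_t target) {
  int64_t index = 0;
  for (size_t i = 0; i < loop.size(); ++i) {
    if (i == target) {
      return index;
    }
    if (loop[i] == AxisKind::Iter) {
      index++;
    }
  }
  return -1;
}

// Example: for a loop domain [DIDx, i1, i2] the sharded axis sits at
// allocation index 0 (outermost), so resharding it is not "inner"; for
// [i0, DIDx, i2] it sits at index 1 > 0, which isInnerResharding reports
// as an inner resharding.
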
3 changes: 2 additions & 1 deletion csrc/multidevice/utils.h
@@ -30,7 +30,8 @@ NVF_API bool distributedEnabled();
// TODO: Analyze loop domain for unsharded/sharded IDs and return their
// parent root IDs.
std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
Expr* expr);
TensorView* producer,
TensorView* consumer);

// Returns whether a TensorView has a non-reduction axis parallelized on DIDx.
// Checks that the other non-reduction axes are not parallelized on DIDx.
2 changes: 1 addition & 1 deletion csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -40,7 +40,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
expr->toString());
auto* output = expr->outputs().at(0)->as<TensorView>();
auto* input = expr->inputs().at(0)->as<TensorView>();
auto [shard_additions, shard_deletions] = getShardingChanges(expr);
auto [shard_additions, shard_deletions] = getShardingChanges(input, output);
NVF_ERROR(
shard_additions.size() + shard_deletions.size() <= 1,
"Resharding expr can only support one axis: ",
90 changes: 89 additions & 1 deletion csrc/scheduler/tools/inlining.cpp
@@ -5,12 +5,14 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <device_lower/utils.h>
#include <id_model/utils.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <scheduler/tools/inlining.h>
#include <transform_iter.h>
#include <val_graph_visitor.h>

#include <utility>

@@ -193,6 +195,46 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
}
return producer->nDims();
} else {
std::optional<std::unordered_set<ValGroup>> loop_path_groups = std::nullopt;
if (consumer->definition()->isA<MmaOp>()) {
// We handle MmaOp specially here since it is currently the only operation
// for which we generate code (i.e. not SdpaFwdOp or SdpaBwdOp) that has
// some output dimensions that do not map to input dimensions. For this
// case, we need to identify potential inlined pairs each ID of which is
// not mapped at all to the other TensorView (see example below).

// Get ValGroups in loop domains of producer and consumer that are
// connected to _mapped_ IterDomains in the pairwise map.
//
// Note that for MmaOp, it would be sufficient to traverse from the
// producer loop to the consumer loop and identify when _either_ the
// consumer or producer ID is not mapped. Here we are instead traversing
// from mapped domains to both roots so that we can check that _both_
// consumer and producer IDs are not mapped. This is slightly safer and this
// symmetry might be handy in handling new ops that use this feature in
// the future.
std::vector<ValGroup> pairwise_mapped_groups;
for (auto [c_id, p_id] : PairwiseLogicalDomainMap(producer, consumer)
.mapConsumerToProducer()) {
pairwise_mapped_groups.push_back(inliningGraph().toGroup(c_id));
}
// We propagate toward the loop groups from both consumer and producer
std::vector<ValGroup> all_loop_groups;
for (IterDomain* id : producer->getLoopDomain()) {
all_loop_groups.push_back(inliningGraph().toGroup(id));
}
for (IterDomain* id : consumer->getLoopDomain()) {
all_loop_groups.push_back(inliningGraph().toGroup(id));
}
// getValsBetween does not require all target groups to be visited. This
// means the result contains the subset of both loop groups that we are
// looking for.
std::vector<ValGroup> group_path = getValsBetween<ValGraphBFS>(
pairwise_mapped_groups, all_loop_groups, inliningGraph());
loop_path_groups =
std::unordered_set<ValGroup>(group_path.begin(), group_path.end());
}

auto consumer_it = consumer->getLoopDomain().begin();
for (const auto producer_pos : c10::irange(producer->nDims())) {
auto p_id = producer->getLoopDomain().at(producer_pos);
@@ -211,8 +253,54 @@
}

IterDomain* c_id = *consumer_it;

// We can inline past positions in which both producer and consumer are
// not connected to any mapped logical IterDomain pairs.
//
// For example, an MmaOp can be constructed as follows:
//
// tv0:
// root/logical: [ iS0, iS1 ]
// loop: [ iS0, bS7, iS1 ]
// tv1:
// root/logical: [ iS2, iS3 ]
// loop: [ bS8, iS2, iS3 ]
// tv2:
// root/logical/loop: [ iS4, iS5, rS6 ]
//
// iS4 maps to iS0 so when producer==tv0 we can inline past iS0. When
// producer==tv1, iS4 doesn't map to anything in tv1 and bS8 is a loop
// broadcast in that position so we inline past the first ID in that
// case also. Similarly, we inline past iS5, iS2, and bS7.
if (loop_path_groups.has_value()) {
bool p_id_connected =
loop_path_groups->count(inliningGraph().toGroup(p_id));
bool c_id_connected =
loop_path_groups->count(inliningGraph().toGroup(c_id));
NVF_ERROR(
p_id_connected ||
(consumer->definition()->isA<MmaOp>() && p_id->isBroadcast()),
"Expected unmapped producer id to be broadcast domain in MmaOp input but found ",
p_id->toString());

if (!p_id_connected && !c_id_connected) {
NVF_ERROR(
p_id->isBroadcast(),
"Unmapped producer ID must be a broadcast created in scheduling but found ",
p_id->toString());
++consumer_it;
continue;
}
}

if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
!isAllowedID(c_id, consumer, best_effort, true, false, true)) {
!isAllowedID(
c_id,
consumer,
best_effort,
/*allow_reduction=*/true,
/*allow_vectorize=*/false,
/*allow_unmappable=*/true)) {
return producer_pos;
}

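The MmaOp handling added above boils down to a membership test: the ValGroups reachable from the pairwise-mapped logical IDs toward the producer and consumer loop domains are collected once, and a loop ID whose group lies outside that set on both sides is taken to be a broadcast introduced during scheduling, which inlining may step past. A minimal standalone sketch of that test, using integers and an adjacency list in place of nvFuser's ValGroup and getValsBetween machinery (the graph and names are hypothetical; the real traversal keeps only groups on paths between the two sets, but for checking loop-ID membership, plain reachability from the mapped groups gives the same answer):

#include <queue>

#include <unordered_map>
#include <unordered_set>
#include <vector>

// Breadth-first search from the "mapped" seed groups; every group reached is
// considered connected to a pairwise-mapped logical ID.
std::unordered_set<int> reachableGroups(
    const std::unordered_map<int, std::vector<int>>& edges,
    const std::vector<int>& mapped_groups) {
  std::unordered_set<int> visited(mapped_groups.begin(), mapped_groups.end());
  std::queue<int> frontier;
  for (int group : mapped_groups) {
    frontier.push(group);
  }
  while (!frontier.empty()) {
    int group = frontier.front();
    frontier.pop();
    auto it = edges.find(group);
    if (it == edges.end()) {
      continue;
    }
    for (int next : it->second) {
      if (visited.insert(next).second) {
        frontier.push(next);
      }
    }
  }
  return visited;
}

// Usage: if neither the producer loop ID's group nor the consumer loop ID's
// group is in reachableGroups(...), that position holds a broadcast created
// during scheduling, and getMaxProducerPosFromConsumer may inline past it.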