Sharding Detection with multiple IO (#3540)
# What

Patch the resharding detection routines in `multidevice/utils` to handle
expressions with multiple inputs and outputs.

# Why

As a step toward https://jirasw.nvidia.com/browse/NVFUSER-106, this patch
makes it possible to keep MatmulOp (and potentially other ops in the future)
as a resharding op that undergoes a special HostIr lowering, bypassing
`ReorderShardedAxisPass`.

---------

Co-authored-by: Jingyue Wu <[email protected]>
samnordmann and wujingyue authored Dec 11, 2024
1 parent d178c2a commit 2749296
Showing 3 changed files with 37 additions and 45 deletions.
77 changes: 34 additions & 43 deletions csrc/multidevice/utils.cpp
```diff
@@ -42,45 +42,39 @@ std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
   return sharded_ids;
 }
 
-// Returns whether a IterDomain in a TensorView is the outermost
-// allocated IterDomain in the TensorView.
-bool isOutermostAllocatedId(TensorView* tv, IterDomain* id) {
-  for (auto i : tv->getLoopDomain()) {
-    if (i == id) {
-      return true;
+// Returns the position where an axis is allocated in a tv, skipping trivial
+// dimensions (i.e. DID, reduction and broadcast). Returns -1 if id is not in
+// tv's loop domain. WAR: today we assume that the loop domain matches the
+// actual allocation, but this will have to change in the future.
+int64_t allocationIndex(TensorView* tv, IterDomain* id) {
+  int64_t index = 0;
+  for (auto* loop_id : tv->getLoopDomain()) {
+    if (loop_id == id) {
+      return index;
     }
-    if (!i->isDeviceDim() && !i->isReduction() && !i->isBroadcast()) {
-      return false;
+    if (!loop_id->isDeviceDim() && !loop_id->isReduction() &&
+        !loop_id->isBroadcast()) {
+      index++;
     }
   }
-  NVF_THROW("Id", id->toString(), " is not in TensorView ", tv->toString());
-  return false;
+  return -1;
 }
 
 } // namespace
 
 std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
-    Expr* expr) {
-  NVF_ERROR(
-      ir_utils::isTvOp(expr), "Expression must be a TvOp ", expr->toString());
-  NVF_ERROR(
-      expr->outputs().size() == 1,
-      "Resharding expression can only have one output");
-  NVF_ERROR(
-      expr->inputs().size() == 1,
-      "Resharding expression can have only one input");
-  auto output = expr->outputs().at(0)->as<TensorView>();
-  auto input = expr->inputs().at(0)->as<TensorView>();
-
+    TensorView* producer,
+    TensorView* consumer) {
   std::vector<IterDomain*> shard_additions;
   std::vector<IterDomain*> shard_deletions;
-  auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
+  auto rootmap =
+      PairwiseLogicalDomainMap(producer, consumer).mapBroadcast(false);
   const auto c2p_map = rootmap.mapConsumerToProducer();
 
-  for (IterDomain* out_root : output->getMaybeRootDomain()) {
+  for (IterDomain* out_root : consumer->getMaybeRootDomain()) {
     IterDomain* in_root = c2p_map.at(out_root);
     // Ignore sharded broadcast domains and
-    // sharded reductions on the output
+    // sharded reductions on the consumer
     // ex. DIDx(i0) -> r(i0) or DIDx(i0) -> r(DIDx(i0))
     // since they don't affect allocation.
     if (in_root->isDeviceDim() && !in_root->isBroadcast() &&
```
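
For intuition about the new `allocationIndex` helper, here is a minimal standalone sketch; the `MockId` struct and the sample loop domain are hypothetical stand-ins, not nvFuser types. Trivial dimensions (DID, reduction, broadcast) never advance the counter, so the returned position reflects only the actually allocated axes, and 0 means "outermost allocated axis":

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an IterDomain, keeping only the properties
// that allocationIndex inspects.
struct MockId {
  std::string name;
  bool is_device_dim = false;
  bool is_reduction = false;
  bool is_broadcast = false;
};

// Mirrors the counting logic of allocationIndex above: trivial dimensions
// (DID, reduction, broadcast) do not advance the allocation position.
int64_t allocationIndex(const std::vector<MockId>& loop_domain,
                        const std::string& id) {
  int64_t index = 0;
  for (const MockId& loop_id : loop_domain) {
    if (loop_id.name == id) {
      return index;
    }
    if (!loop_id.is_device_dim && !loop_id.is_reduction &&
        !loop_id.is_broadcast) {
      index++;
    }
  }
  return -1;  // id is not in the loop domain
}

int main() {
  // Loop domain [DIDx(i0), i1, r(i2), i3]: i0 is device-parallel, i2 is a
  // reduction, so the allocated axes are effectively [i1, i3].
  std::vector<MockId> loop_domain = {
      {"i0", /*is_device_dim=*/true},
      {"i1"},
      {"i2", false, /*is_reduction=*/true},
      {"i3"}};
  std::cout << allocationIndex(loop_domain, "i0") << "\n";  // 0: nothing allocated before it
  std::cout << allocationIndex(loop_domain, "i1") << "\n";  // 0: outermost allocated axis
  std::cout << allocationIndex(loop_domain, "i3") << "\n";  // 1: an inner allocated axis
  std::cout << allocationIndex(loop_domain, "i4") << "\n";  // -1: not in the domain
}
```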
```diff
@@ -93,8 +87,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
     } else if (in_root->isDeviceDim() && out_root->isDeviceDim()) {
       NVF_ERROR(
           in_root->getParallelType() == out_root->getParallelType(),
-          expr->toString(),
-          " reshards ",
+          " resharding ",
           in_root->toString(),
           " to ",
           out_root->toString(),
```
```diff
@@ -462,23 +455,21 @@ bool isInnerResharding(Expr* expr) {
       ir_utils::isTvOp(expr),
       "Non-tv op is not supported : ",
       expr->toString());
-  NVF_ERROR(
-      expr->outputs().size() == 1,
-      "Resharding operations can only have one output");
-  NVF_ERROR(
-      expr->inputs().size() == 1,
-      "Resharding operations can have only one input");
-  auto output = expr->outputs().at(0)->as<TensorView>();
-  auto input = expr->inputs().at(0)->as<TensorView>();
-  auto [shard_additions, shard_deletions] = getShardingChanges(expr);
-  NVF_ERROR(
-      shard_additions.size() + shard_deletions.size() <= 1,
-      "Resharding expr can only support one axis")
-
-  if (!shard_deletions.empty()) {
-    return !isOutermostAllocatedId(input, shard_deletions[0]);
-  } else if (!shard_additions.empty()) {
-    return !isOutermostAllocatedId(output, shard_additions[0]);
+  for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
+    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
+      auto [shard_additions, shard_deletions] =
+          getShardingChanges(input, output);
+      NVF_ERROR(
+          shard_additions.size() + shard_deletions.size() <= 1,
+          "Resharding expr can only support one axis")
+      if ((!shard_deletions.empty() &&
+           allocationIndex(input, shard_deletions.at(0)) > 0) ||
+          (!shard_additions.empty() &&
+           allocationIndex(output, shard_additions.at(0)) > 0)) {
+        return true;
+      }
+    }
   }
   return false;
 }
```
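
The key behavioral change is that the single-input/single-output restriction is gone: `isInnerResharding` now visits every (input, output) pair of the expression, which is what lets a multi-input op such as MatmulOp remain a resharding op. The sketch below imitates that control flow on toy data; `MockTv`, the name-based axis matching, and `isInnerReshardingLike` are illustrative simplifications, not the real nvFuser API (which maps domains through `PairwiseLogicalDomainMap`):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins; these are not nvFuser types.
struct MockId {
  std::string name;
  bool is_device_dim = false;
};
using MockTv = std::vector<MockId>;  // a tensor is just its loop domain here

// Position of `id` counting only non-DID axes; -1 if absent.
// (Reductions and broadcasts are omitted from this toy model.)
int64_t allocationIndex(const MockTv& tv, const std::string& id) {
  int64_t index = 0;
  for (const MockId& loop_id : tv) {
    if (loop_id.name == id) return index;
    if (!loop_id.is_device_dim) index++;
  }
  return -1;
}

// Per-pair detection in the spirit of the patched isInnerResharding:
// visit every (input, output) pair and flag the expression as soon as one
// pair adds or removes a shard on a non-outermost allocated axis.
bool isInnerReshardingLike(const std::vector<MockTv>& inputs,
                           const std::vector<MockTv>& outputs) {
  for (const MockTv& in : inputs) {
    for (const MockTv& out : outputs) {
      for (const MockId& in_id : in) {
        for (const MockId& out_id : out) {
          if (in_id.name != out_id.name ||
              in_id.is_device_dim == out_id.is_device_dim) {
            continue;  // not a sharding change on this axis
          }
          // A shard is deleted from `in` or added on `out`; the position is
          // checked on whichever side carries the DID axis.
          const MockTv& side = in_id.is_device_dim ? in : out;
          if (allocationIndex(side, in_id.name) > 0) {
            return true;
          }
        }
      }
    }
  }
  return false;
}

int main() {
  // A matmul-like expression with two inputs and one output.
  MockTv a = {{"i0", true}, {"i2"}};  // A sharded on i0 (DIDx)
  MockTv b = {{"i2"}, {"i1", true}};  // B sharded on i1 (DIDx)
  MockTv c = {{"i0", true}, {"i1"}};  // C keeps only the shard on i0
  // Dropping B's shard on i1 happens on an inner allocated axis of B.
  std::cout << std::boolalpha
            << isInnerReshardingLike({a, b}, {c}) << "\n";  // true
}
```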
3 changes: 2 additions & 1 deletion csrc/multidevice/utils.h
```diff
@@ -30,7 +30,8 @@ NVF_API bool distributedEnabled();
 // TODO: Analyze loop domain for unsharded/sharded IDs and return their
 // parent root IDs.
 std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
-    Expr* expr);
+    TensorView* producer,
+    TensorView* consumer);
 
 // Returns whether a TensorView has a non-reduction axis parallelized Didx
 // Checks that the other non-reduction axis are not parallelized on Didx
```
2 changes: 1 addition & 1 deletion csrc/preseg_passes/reorder_sharded_axis.cpp
```diff
@@ -40,7 +40,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
       expr->toString());
   auto* output = expr->outputs().at(0)->as<TensorView>();
   auto* input = expr->inputs().at(0)->as<TensorView>();
-  auto [shard_additions, shard_deletions] = getShardingChanges(expr);
+  auto [shard_additions, shard_deletions] = getShardingChanges(input, output);
   NVF_ERROR(
       shard_additions.size() + shard_deletions.size() <= 1,
       "Resharding expr can only support one axis: ",
```
