
Commit 970d80a
resharding detection accepts expr with multiple IOs
samnordmann committed Dec 6, 2024
1 parent 67127c9 commit 970d80a
Showing 3 changed files with 32 additions and 42 deletions.
69 changes: 29 additions & 40 deletions csrc/multidevice/utils.cpp
@@ -42,45 +42,37 @@ std::unordered_set<IterDomain*> getShardedIterDomains(TensorView* tv) {
   return sharded_ids;
 }
 
-// Returns whether a IterDomain in a TensorView is the outermost
-// allocated IterDomain in the TensorView.
-bool isOutermostAllocatedId(TensorView* tv, IterDomain* id) {
+// Returns the position where an axis is allocated in a tv. Returns -1 if id is
+// not in tv's loop domain. WAR: today we assume that the loop domain matches
+// the actual allocation, but this will have to change in the future.
+int64_t allocationIndex(TensorView* tv, IterDomain* id) {
+  int64_t index = 0;
   for (auto i : tv->getLoopDomain()) {
     if (i == id) {
-      return true;
+      return index;
     }
     if (!i->isDeviceDim() && !i->isReduction() && !i->isBroadcast()) {
-      return false;
+      index++;
     }
   }
-  NVF_THROW("Id", id->toString(), " is not in TensorView ", tv->toString());
-  return false;
+  return -1;
 }
 
 } // namespace
 
 std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
-    Expr* expr) {
-  NVF_ERROR(
-      ir_utils::isTvOp(expr), "Expression must be a TvOp ", expr->toString());
-  NVF_ERROR(
-      expr->outputs().size() == 1,
-      "Resharding expression can only have one output");
-  NVF_ERROR(
-      expr->inputs().size() == 1,
-      "Resharding expression can have only one input");
-  auto output = expr->outputs().at(0)->as<TensorView>();
-  auto input = expr->inputs().at(0)->as<TensorView>();
-
+    TensorView* producer,
+    TensorView* consumer) {
   std::vector<IterDomain*> shard_additions;
   std::vector<IterDomain*> shard_deletions;
-  auto rootmap = PairwiseLogicalDomainMap(input, output).mapBroadcast(false);
+  auto rootmap =
+      PairwiseLogicalDomainMap(producer, consumer).mapBroadcast(false);
   const auto c2p_map = rootmap.mapConsumerToProducer();
 
-  for (IterDomain* out_root : output->getMaybeRootDomain()) {
+  for (IterDomain* out_root : consumer->getMaybeRootDomain()) {
     IterDomain* in_root = c2p_map.at(out_root);
     // Ignore sharded broadcast domains and
-    // sharded reductions on the output
+    // sharded reductions on the consumer
     // ex. DIDx(i0) -> r(i0) or DIDx(i0) -> r(DIDx(i0))
     // since they don't affect allocation.
     if (in_root->isDeviceDim() && !in_root->isBroadcast() &&
@@ -93,8 +85,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
     } else if (in_root->isDeviceDim() && out_root->isDeviceDim()) {
       NVF_ERROR(
           in_root->getParallelType() == out_root->getParallelType(),
-          expr->toString(),
-          " reshards ",
+          " resharding ",
           in_root->toString(),
           " to ",
           out_root->toString(),
@@ -379,23 +370,21 @@ bool isInnerResharding(Expr* expr) {
       ir_utils::isTvOp(expr),
       "Non-tv op is not supported : ",
       expr->toString());
-  NVF_ERROR(
-      expr->outputs().size() == 1,
-      "Resharding operations can only have one output");
-  NVF_ERROR(
-      expr->inputs().size() == 1,
-      "Resharding operations can have only one input");
-  auto output = expr->outputs().at(0)->as<TensorView>();
-  auto input = expr->inputs().at(0)->as<TensorView>();
-  auto [shard_additions, shard_deletions] = getShardingChanges(expr);
-  NVF_ERROR(
-      shard_additions.size() + shard_deletions.size() <= 1,
-      "Resharding expr can only support one axis")
-
-  if (!shard_deletions.empty()) {
-    return !isOutermostAllocatedId(input, shard_deletions[0]);
-  } else if (!shard_additions.empty()) {
-    return !isOutermostAllocatedId(output, shard_additions[0]);
+  for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
+    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
+      auto [shard_additions, shard_deletions] =
+          getShardingChanges(input, output);
+      NVF_ERROR(
+          shard_additions.size() + shard_deletions.size() <= 1,
+          "Resharding expr can only support one axis")
+      if ((!shard_deletions.empty() &&
+           allocationIndex(input, shard_deletions.at(0)) > 0) ||
+          (!shard_additions.empty() &&
+           allocationIndex(output, shard_additions.at(0)) > 0)) {
+        return true;
+      }
+    }
   }
   return false;
 }
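A note on the helper change: allocationIndex generalizes the boolean isOutermostAllocatedId it replaces. The outermost allocated axis sits at position 0, so the old check !isOutermostAllocatedId(tv, id) corresponds to allocationIndex(tv, id) > 0 above. Below is a minimal self-contained sketch of the counting rule, with a hypothetical Axis struct standing in for nvFuser's IterDomain (names and setup are illustrative, not repository code):

// Sketch of the allocationIndex counting rule under assumed semantics:
// device, reduction, and broadcast axes occupy no allocation slot, so only
// "real" axes advance the position counter.
#include <cassert>
#include <cstdint>
#include <vector>

struct Axis { // hypothetical stand-in for nvFuser's IterDomain
  bool device = false;
  bool reduction = false;
  bool broadcast = false;
};

// Returns the allocation position of `id` in `loop`, or -1 if absent.
int64_t allocationIndexSketch(
    const std::vector<const Axis*>& loop,
    const Axis* id) {
  int64_t index = 0;
  for (const Axis* a : loop) {
    if (a == id) {
      return index; // found: report its allocation position
    }
    if (!a->device && !a->reduction && !a->broadcast) {
      index++; // only allocated axes count toward the position
    }
  }
  return -1; // id is not in the loop domain
}

int main() {
  Axis didx{/*device=*/true}, i0, i1;
  std::vector<const Axis*> loop = {&didx, &i0, &i1};
  assert(allocationIndexSketch(loop, &didx) == 0); // matches before counting
  assert(allocationIndexSketch(loop, &i0) == 0);   // outermost allocated axis
  assert(allocationIndexSketch(loop, &i1) == 1);   // inner allocated axis
  return 0;
}

Skipping device, reduction, and broadcast axes is what lets a device-parallelized id match at position 0, which is why a resharding on such an axis is not considered "inner".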
3 changes: 2 additions & 1 deletion csrc/multidevice/utils.h
@@ -30,7 +30,8 @@ NVF_API bool distributedEnabled();
 // TODO: Analyze loop domain for unsharded/sharded IDs and return their
 // parent root IDs.
 std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges(
-    Expr* expr);
+    TensorView* producer,
+    TensorView* consumer);
 
 // Returns whether a TensorView has a non-reduction axis parallelized Didx
 // Checks that the other non-reduction axis are not parallelized on Didx
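Combined with the isInnerResharding rewrite in utils.cpp above, the two-TensorView signature suggests the following calling pattern for expressions with multiple inputs and outputs (a sketch distilled from this diff, not a verbatim excerpt):

// Sketch: query sharding changes for every (input, output) TensorView pair.
for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
  for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
    auto [shard_additions, shard_deletions] = getShardingChanges(input, output);
    // shard_deletions: producer root axes whose DID sharding disappears on
    // the consumer; shard_additions: consumer root axes that become sharded.
  }
}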
2 changes: 1 addition & 1 deletion csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -40,7 +40,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
       expr->toString());
   auto* output = expr->outputs().at(0)->as<TensorView>();
   auto* input = expr->inputs().at(0)->as<TensorView>();
-  auto [shard_additions, shard_deletions] = getShardingChanges(expr);
+  auto [shard_additions, shard_deletions] = getShardingChanges(input, output);
   NVF_ERROR(
       shard_additions.size() + shard_deletions.size() <= 1,
       "Resharding expr can only support one axis: ",
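This call site keeps the single input/output extraction via at(0) and simply passes the pair to the new overload. For intuition, a worked example of the contract, with the loop domains assumed purely for illustration:

// Assumed example (illustration only, DIDx notation as in the comments above):
//   producer loop domain: [DIDx(i0), i1]   consumer loop domain: [i0, i1]
// getShardingChanges(producer, consumer) yields
//   shard_deletions == {DIDx(i0)}  // i0 stops being device-sharded
//   shard_additions == {}
// allocationIndex(producer, DIDx(i0)) == 0, so the resharded axis is
// outermost and isInnerResharding(expr) returns false. Had the producer
// been [i1, DIDx(i0)], the index would be 1, and this pass would reorder
// the sharded axis outermost before communication.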
