Skip to content

Commit

Permalink
Remove unnecessary propagation that's already done by the ops API. (#…
Browse files Browse the repository at this point in the history
…3424)

Preparation for #2563
  • Loading branch information
wujingyue authored and jacobhinkle committed Dec 3, 2024
1 parent 6b5b7b8 commit 7da13c2
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions csrc/preseg_passes/reorder_sharded_axis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
permute(output_permute, {{0, sharding_axis_after_permute}});
ir_utils::replaceValInAllExprInputsAndFusionOutputs(output, new_output);

// Propagate shardings from input and manually apply sharding additions.
shardAllLike(input, {input_permute, output_permute, new_output});
output_permute->axis(0)->parallelize(shard_added_id->getParallelType());
new_output->axis(sharding_axis_after_permute)
->parallelize(shard_added_id->getParallelType());
// `output_permute` and `new_output` have inherited mesh from `input`. We
// need to change them to `output`'s mesh so communication is only
// between `input_permute` and `output_permute`.
output_permute->setDeviceMesh(output->getDeviceMesh());
new_output->setDeviceMesh(output->getDeviceMesh());
}
Expand Down

0 comments on commit 7da13c2

Please sign in to comment.