Skip to content

Commit

Permalink
Remove unnecessary propagation that's already done by the ops API. (#…
Browse files Browse the repository at this point in the history
…3424)

Preparation for #2563
  • Loading branch information
wujingyue authored Nov 20, 2024
1 parent a6cf1bb commit 97544c3
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions csrc/preseg_passes/reorder_sharded_axis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,12 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
permute(output_permute, {{0, sharding_axis_after_permute}});
ir_utils::replaceValInAllExprInputsAndFusionOutputs(output, new_output);

// Propagate shardings from input and manually apply sharding additions.
shardAllLike(input, {input_permute, output_permute, new_output});
output_permute->axis(0)->parallelize(shard_added_id->getParallelType());
new_output->axis(sharding_axis_after_permute)
->parallelize(shard_added_id->getParallelType());
// `output_permute` and `new_output` have inherited mesh from `input`. We
// need to change them to `output`'s mesh so communication is only
// between `input_permute` and `output_permute`.
output_permute->setDeviceMesh(output->getDeviceMesh());
new_output->setDeviceMesh(output->getDeviceMesh());
}
Expand Down

0 comments on commit 97544c3

Please sign in to comment.