Extend isResharding to allow DID loop split. #3421
Conversation
!test
@naoyam this is my attempt to reimplement isResharding using IdModel so it can support DID loop split. Any early comments before I clean it up?
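(For context, a minimal sketch of the idea, not this PR's actual diff: decide whether an expression reshards by mapping producer and consumer loop IterDomains through an IdModel-style graph and comparing their device parallelization. `loopIdMapsTo` is a hypothetical helper standing in for the real graph query.)

```cpp
// Illustrative sketch only, not this PR's code. loopIdMapsTo() is a
// hypothetical stand-in for an IdModel-based mapping query between
// producer and consumer IterDomains.
bool hasDifferentShardings(TensorView* producer, TensorView* consumer) {
  for (IterDomain* p_id : producer->getLoopDomain()) {
    for (IterDomain* c_id : consumer->getLoopDomain()) {
      if (!loopIdMapsTo(p_id, c_id)) {
        continue;
      }
      // If two mapped IterDomains disagree on device parallelization,
      // data has to move between devices, i.e., the expression reshards.
      if ((p_id->isDeviceDim() || c_id->isDeviceDim()) &&
          p_id->getParallelType() != c_id->getParallelType()) {
        return true;
      }
    }
  }
  return false;
}
```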
@xwang233 FYI, H100 tests are apparently failing due to the same infra failure I tagged you about earlier.
!test
tv5->axis(sharded_axis)->parallelize(ParallelType::DIDx);
if (is_tv0_tv5_sharded) {
  tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx);
  // tv3->axis(kShardedAxis) is a reduction, so don't shard it.
An incidental cleanup, per Fuser/csrc/preseg_passes/propagate_shardings.cpp, lines 101 to 140 in 3266b9d:
// shardAllLike, which calls parallelAllLike, tries to DID-parallelize
// reduction dimensions. For example,
//
// [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
//
// becomes
//
// [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Pointwise) -> [i2]
//
// This implies that the reduction result only exists on the "home" device.
// `lower_communication` can't lower such a reduction today. lowerToReduce
// is closest but it uses the output device mesh to indicate the home device.
// Also, an extra broadcast will be needed to replicate the reduction result
// to all devices for the pointwise op.
//
// Therefore, instead, we remove the DID from reduction dimensions and
// thereby reset them to Serial. This way, the above becomes
//
// [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
//
// where the reduction will be lowered to an Allreduce.
//
// Alternatively, @naoyam proposed to represent an allreduce as a reduce
// followed by a broadcasting set:
//
// [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Set) -> [i2]
//   -> (Pointwise) -> [i2]
//
// This will make the semantics similar to other parallel types and therefore
// we can better leverage existing parallelization utilities. We have yet to
// pursue this because of implementation difficulty -- `lower_communication`
// would need to match the reduce-set pattern.
for (TensorView* tv : fusion->allTvs()) {
  for (IterDomain* id : tv->getLoopDomain()) {
    if (id->isReduction() && id->isDeviceDim()) {
      id->parallelize(ParallelType::Serial);
    }
  }
}
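To make the quoted comment concrete, here is a minimal sketch of the situation the pass fixes up (the fusion is assumed for illustration; `makeContigTensor` is from nvFuser's test utilities):

```cpp
// Assumed example, mirroring the [iDID{i1}, i2] diagrams above.
TensorView* tv0 = makeContigTensor(2); // loop = [i1, i2]
TensorView* tv1 = sum(tv0, {0});       // loop = [r{i1}, i2]

tv0->axis(0)->parallelize(ParallelType::DIDx); // [iDID{i1}, i2]
// shardAllLike would propagate DIDx onto tv1's reduction axis, yielding
// [rDID{i1}, i2]. The loop quoted above resets that axis to Serial, so
// the reduction can be lowered to an Allreduce.
```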
!test
tests/cpp/test_resharding.cpp (outdated)
TensorView* rfactor = reshape(
    in,
    {in->axis(0)->extent(), in->axis(1)->extent(), in->axis(2)->extent()});
@naoyam I believe there's a better way to write it. commitLeafToLogical?
I'm not entirely sure what is supposed to be computed here. It looks like only the outermost dimension of `rfactor` is reduced. What is assumed about the input? What is the output?
I updated the test so it's no longer a concern.
FWIW, a more realistic example is the reduce-scatter after a sequence-parallel MLP forward. With DID loop split, it'll look like:
in: logical=[B, S, 4E], loop=[D, B, S, 4E/D]
w: logical=[E, 4E], loop=[D, E, 4E/D]
out=linear(in,w): logical=[B, S, E], loop=[D, B, S/D, E]
However, `out->definition()` does both the linear and the resharding, so it has to be decomposed before lowering.
I think we'll end up decomposing it as:
in: logical=[B, S, 4E], loop=[D, B, S, 4E/D]
in'=reshape(in, [D, B, S, 4E/D]): logical/loop=[D, B, S, 4E/D]
w: logical=[E, 4E], loop=[D, E, 4E/D]
w'=reshape(w, [D, E, 4E/D]): logical/loop=[D, E, 4E/D]
local_out=linear(in',w'): logical/loop=[D, B, S, E]
out=sum(local_out): logical=[B, S, E], loop=[D, B, S/D, E]
I wonder why reshape is used, but this is probably a topic for a future PR.
> I wonder why reshape is used, but this is probably a topic for a future PR.
Because I can't find a better way at the moment... `sum` (which represents the ReduceScatter) has to reduce a device dimension in the logical domain. Therefore, the linear has to produce a device dimension in its logical domain, hence the reshape. See the sketch below.
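To illustrate with a minimal fusion (extents assumed, not from this PR's tests): splitting the reduced dimension out via reshape is what lets `sum` see a device dimension in the logical domain.

```cpp
// Assumed minimal example: express a ReduceScatter-style reduction by
// first reshaping the device dimension into the logical domain.
constexpr int64_t D = 2, N = 8, K = 16;

TensorView* x = makeContigConcreteTensor({N, D * K});   // logical = [N, D*K]
TensorView* x_r = reshape(x, {N, D * K}, {N, D, K});    // logical = [N, D, K]
TensorView* out = sum(x_r, {1});                        // reduces the device dim -> [N, K]

x_r->axis(1)->parallelize(ParallelType::DIDx); // shard the split-out dim
```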
This sounds like `TensorView::rFactor`.
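For comparison, a minimal sketch of `TensorView::rFactor` usage (extents assumed): it likewise splits one reduction into a partial reduction plus a final combine.

```cpp
// Assumed example of rFactor: split a reduction, then factor out the
// inner partial reduction into a separate tensor.
TensorView* tv0 = makeContigConcreteTensor({16});
TensorView* tv1 = sum(tv0, {0}); // loop = [r{16}]
tv1->split(0, 4);                // loop = [r{4}, r{4}]
// After rFactor({1}): tv2 performs the inner partial reduction
// ([i{4}, r{4}]) and tv1 combines the partial results ([r{4}]).
TensorView* tv2 = tv1->rFactor({1});
```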
`hasDifferentShardings` looks good to me. I haven't fully checked the tests yet.
LGTM
!test
@Priya2698 and @xwang233: can you teach me how to run host-latency benchmarks so I can make sure this PR doesn't cause big regressions like Lightning-AI/lightning-thunder#229?
We have the following 3 benchmarks for host latency:

You will get 3 types of measurement. You can see this comment for an example of what output you will see. The numbers for dynamic may have changed due to some recent fixes, so do re-run. I will look into renaming the benchmarks.
Thanks -- those indeed caught a regression, e.g.:
This is to help track down the host latency regression found in #3421.
@wujingyue I was initially surprised to see this impacting dynamic host latency, since that path should not trigger pre-seg passes and I thought this PR only touched the pre-seg pass. The pre-seg passes are only run when we create a new FusionKernelRuntime, which does not happen if we are able to reuse a runtime for a new set of input shapes, as in the dynamic benchmark. However, I think the perf impact is because this modifies Fuser/csrc/scheduler/utils.cpp (line 1120 in bb05859).
I should have synced the branch past @jacobhinkle's 3266b9d. The regression no longer happens.
!test
For #2563
Host latency benchmarks are neutral: