Extend isResharding to allow DID loop split. #3421

Merged
merged 28 commits into from
Nov 25, 2024


Collaborator

@wujingyue wujingyue commented Nov 16, 2024

For #2563

Host latency benchmarks are neutral:

$ pytest-benchmark compare 0003 0004 --group-by=name

------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile']": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                                        Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile'] (0003_96a8efb)     170.7932 (1.0)      175.0024 (1.0)      173.0403 (1.0)      1.2632 (1.0)      173.0455 (1.0)      1.7885 (1.0)           4;0  5.7790 (1.0)          10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile'] (0004_24e1b90)     170.9341 (1.00)     175.2137 (1.00)     173.4435 (1.00)     1.4717 (1.17)     173.6018 (1.00)     1.9277 (1.08)          3;0  5.7656 (1.00)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic']": 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                                                        Min                 Max                Mean             StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic'] (0003_96a8efb)     114.1470 (1.0)      157.7900 (1.27)     124.2067 (1.05)     17.6455 (5.83)     115.8945 (1.0)      3.3760 (1.0)           2;2        8.0511 (0.96)         10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic'] (0004_24e1b90)     114.8980 (1.01)     123.9660 (1.0)      118.6434 (1.0)       3.0273 (1.0)      118.6260 (1.02)     5.5200 (1.64)          3;0        8.4286 (1.0)          10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------------- benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady']": 2 tests --------------------------------------------------------------------------------
Name (time in us)                                                                      Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady'] (0003_96a8efb)     39.5550 (1.0)      50.2060 (1.13)     42.4156 (1.02)     3.4449 (2.72)     40.8575 (1.0)      4.3180 (4.40)          2;0       23.5762 (0.98)         10           1
test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady'] (0004_24e1b90)     39.9650 (1.01)     44.5350 (1.0)      41.3924 (1.0)      1.2683 (1.0)      41.1580 (1.01)     0.9820 (1.0)           2;1       24.1590 (1.0)          10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='compile']": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                              Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='compile'] (0003_96a8efb)     208.3781 (1.0)      213.6772 (1.0)      211.2823 (1.0)      1.3894 (1.0)      211.2441 (1.0)      0.9576 (1.0)           2;2  4.7330 (1.0)          10           1
test_many_segment_benchmark[host_bench_mode='compile'] (0004_24e1b90)     212.2643 (1.02)     224.4728 (1.05)     216.4410 (1.02)     3.3363 (2.40)     216.1582 (1.02)     2.8121 (2.94)          2;1  4.6202 (0.98)         10           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='dynamic']": 2 tests ------------------------------------------------------------------------------------
Name (time in us)                                                              Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='dynamic'] (0003_96a8efb)     500.2300 (1.0)      626.6390 (1.11)     548.1104 (1.03)     54.3243 (2.19)     510.5340 (1.0)      97.5850 (1.95)          3;0        1.8244 (0.97)         10           1
test_many_segment_benchmark[host_bench_mode='dynamic'] (0004_24e1b90)     500.3700 (1.00)     563.3790 (1.0)      530.3696 (1.0)      24.7773 (1.0)      530.2765 (1.04)     50.0760 (1.0)           5;0        1.8855 (1.0)          10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------------- benchmark "test_many_segment_benchmark[host_bench_mode='steady']": 2 tests ----------------------------------------------------------------------------------
Name (time in us)                                                             Min                 Max                Mean            StdDev              Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_many_segment_benchmark[host_bench_mode='steady'] (0003_96a8efb)     174.8920 (1.0)      193.9480 (1.00)     179.0197 (1.0)      5.6300 (1.08)     176.5750 (1.0)      3.8570 (1.0)           1;1        5.5860 (1.0)          10           1
test_many_segment_benchmark[host_bench_mode='steady'] (0004_24e1b90)     175.0420 (1.00)     193.0060 (1.0)      181.0727 (1.01)     5.2312 (1.0)      180.9635 (1.02)     5.4300 (1.41)          2;1        5.5226 (0.99)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16]": 2 tests --------------------------------------------------------------------------------
Name (time in ms)                                                                            Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16] (0003_96a8efb)     274.7113 (1.0)      278.1071 (1.0)      276.5300 (1.0)      1.1733 (1.0)      276.5168 (1.0)      2.3361 (1.73)          5;0  3.6162 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16] (0004_24e1b90)     274.8315 (1.00)     280.0985 (1.01)     277.2161 (1.00)     1.3838 (1.18)     276.9688 (1.00)     1.3489 (1.0)           2;1  3.6073 (1.00)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2]": 2 tests ------------------------------------------------------------------------------
Name (time in ms)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2] (0004_24e1b90)     68.5291 (1.0)      69.5794 (1.0)      69.0360 (1.0)      0.4227 (1.0)      69.0748 (1.0)      0.7587 (1.0)           3;0  14.4852 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2] (0003_96a8efb)     70.4157 (1.03)     72.0312 (1.04)     71.4535 (1.04)     0.5711 (1.35)     71.5552 (1.04)     0.9860 (1.30)          2;0  13.9951 (0.97)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4]": 2 tests ------------------------------------------------------------------------------
Name (time in ms)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4] (0004_24e1b90)     91.7454 (1.0)      93.9708 (1.0)      93.2835 (1.0)      0.7899 (1.0)      93.5256 (1.0)      0.8339 (1.0)           2;1  10.7200 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4] (0003_96a8efb)     94.9294 (1.03)     97.1430 (1.03)     95.9882 (1.03)     0.8550 (1.08)     95.9445 (1.03)     1.5375 (1.84)          4;0  10.4179 (0.97)         10           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8]": 2 tests -------------------------------------------------------------------------------
Name (time in ms)                                                                           Min                 Max                Mean            StdDev              Median               IQR            Outliers     OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8] (0004_24e1b90)     147.1260 (1.0)      148.6477 (1.0)      147.7856 (1.0)      0.5382 (1.0)      147.7964 (1.0)      0.9269 (1.0)           3;0  6.7666 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8] (0003_96a8efb)     148.2082 (1.01)     151.0851 (1.02)     149.7864 (1.01)     0.8852 (1.64)     149.8612 (1.01)     0.9669 (1.04)          3;0  6.6762 (0.99)         10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16]": 2 tests ----------------------------------------------------------------------------------
Name (time in us)                                                                           Min                 Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0003_96a8efb)     73.6900 (1.0)      114.9290 (1.39)     86.4882 (1.13)     13.3314 (4.96)     81.0935 (1.07)     18.4040 (9.83)          2;0       11.5623 (0.89)         10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0004_24e1b90)     73.8800 (1.00)      82.7570 (1.0)      76.5963 (1.0)       2.6876 (1.0)      75.8490 (1.0)       1.8720 (1.0)           3;1       13.0555 (1.0)          10           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0004_24e1b90)     63.3200 (1.0)      72.0470 (1.0)      68.0081 (1.0)       2.8500 (1.0)      67.4285 (1.0)       4.7200 (1.0)           4;0       14.7041 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0003_96a8efb)     65.8250 (1.04)     92.5460 (1.28)     74.4402 (1.09)     10.3170 (3.62)     69.5170 (1.03)     18.0350 (3.82)          3;0       13.4336 (0.91)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0004_24e1b90)     64.0220 (1.0)      78.1180 (1.0)      68.6502 (1.0)       3.6925 (1.0)      68.1590 (1.0)       0.9820 (1.0)           2;3       14.5666 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0003_96a8efb)     67.1480 (1.05)     96.8140 (1.24)     77.4761 (1.13)     10.4931 (2.84)     73.0485 (1.07)     18.8960 (19.24)         3;0       12.9072 (0.89)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0004_24e1b90)     68.3790 (1.0)      76.8860 (1.0)      71.7751 (1.0)       2.7013 (1.0)      71.4705 (1.0)       2.7150 (1.0)           4;1       13.9324 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0003_96a8efb)     70.9040 (1.04)     98.3670 (1.28)     80.6828 (1.12)     11.6885 (4.33)     75.1170 (1.05)     23.9160 (8.81)          3;0       12.3942 (0.89)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16]": 2 tests --------------------------------------------------------------------------------
Name (time in us)                                                                          Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16] (0003_96a8efb)     23.9860 (1.0)      33.7140 (1.11)     26.1107 (1.01)     2.8895 (1.61)     24.8575 (1.0)      1.8430 (2.42)          1;1       38.2985 (0.99)         10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16] (0004_24e1b90)     24.7370 (1.03)     30.4480 (1.0)      25.8039 (1.0)      1.7968 (1.0)      25.0880 (1.01)     0.7610 (1.0)           1;2       38.7538 (1.0)          10           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2] (0003_96a8efb)     23.9560 (1.0)      32.4720 (1.09)     26.1016 (1.01)     2.9284 (1.72)     24.9875 (1.0)      1.4230 (1.48)          2;2       38.3118 (0.99)         10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2] (0004_24e1b90)     24.4160 (1.02)     29.8470 (1.0)      25.7449 (1.0)      1.7051 (1.0)      25.1225 (1.01)     0.9620 (1.0)           2;2       38.8426 (1.0)          10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4] (0003_96a8efb)     23.3650 (1.0)      30.4380 (1.08)     25.1518 (1.0)      1.9443 (1.76)     24.6120 (1.0)      0.7120 (1.13)          1;1       39.7586 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4] (0004_24e1b90)     24.2360 (1.04)     28.1530 (1.0)      25.2177 (1.00)     1.1044 (1.0)      24.8370 (1.01)     0.6320 (1.0)           1;1       39.6547 (1.00)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-------------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8]": 2 tests ---------------------------------------------------------------------------------
Name (time in us)                                                                         Min                Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8] (0003_96a8efb)     24.1860 (1.0)      33.0630 (1.17)     25.4774 (1.0)      2.7047 (2.00)     24.5365 (1.0)      1.0620 (1.0)           1;1       39.2505 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8] (0004_24e1b90)     24.2860 (1.00)     28.3440 (1.0)      25.6967 (1.01)     1.3499 (1.0)      25.1020 (1.02)     2.0040 (1.89)          3;0       38.9155 (0.99)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@wujingyue
Collaborator Author

!test

@wujingyue wujingyue requested a review from naoyam November 18, 2024 07:31
Collaborator Author

@wujingyue wujingyue left a comment

@naoyam this is my attempt to reimplement isResharding using IdModel so that it can support DID loop split. Any early comments before I clean it up?

@wujingyue
Collaborator Author

@xwang233 FYI, H100 tests are failing, apparently due to the same infra failure that I tagged you about earlier.

@wujingyue
Collaborator Author

!test

tv5->axis(sharded_axis)->parallelize(ParallelType::DIDx);
if (is_tv0_tv5_sharded) {
  tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx);
  // tv3->axis(kShardedAxis) is a reduction, so don't shard it.
Collaborator Author

An incidental cleanup per the following code:

// shardAllLike, which calls parallelAllLike, tries to DID-parallelize
// reduction dimensions. For example,
//
// [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
//
// becomes
//
// [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Pointwise) -> [i2]
//
// This implies that the reduction result only exists on the "home" device.
// `lower_communication` can't lower such a reduction today. lowerToReduce
// is closest but it uses the output device mesh to indicate the home device.
// Also, an extra broadcast will be needed to replicate the reduction result
// to all devices for the pointwise op.
//
// Instead, we remove the DID from reduction dimensions, resetting them to
// Serial. This way, the above becomes
//
// [iDID{i1}, i2] -> (Reduce) -> [r{i1}, i2] -> (Pointwise) -> [i2]
//
// where the reduction will be lowered to an Allreduce.
//
// Alternatively, @naoyam proposed to represent an allreduce as a reduce
// followed by a broadcasting set.
//
// [iDID{i1}, i2] -> (Reduce) -> [rDID{i1}, i2] -> (Set) [i2] -> (Pointwise)
// -> [i2]
//
// This will make the semantics similar to other parallel types and therefore
// we can better leverage existing parallelization utilities. We have yet to
// pursue this because of implementation difficulty -- `lower_communication`
// would need to match the reduce-set pattern.
for (TensorView* tv : fusion->allTvs()) {
  for (IterDomain* id : tv->getLoopDomain()) {
    if (id->isReduction() && id->isDeviceDim()) {
      id->parallelize(ParallelType::Serial);
    }
  }
}
I'm fine keeping the DIDx; it doesn't hurt.
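
For readers without the codebase at hand, the loop quoted above can be mimicked on toy types. This is a minimal sketch, not nvFuser code: `ToyIterDomain`, `ToyLoopDomain`, and `resetReductionDeviceDims` are hypothetical stand-ins that model only the reduction flag and the parallel type of an `IterDomain`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for nvFuser's ParallelType and IterDomain. Only the
// two properties the quoted loop inspects are modeled here.
enum class ParallelType { Serial, DIDx };

struct ToyIterDomain {
  bool is_reduction;
  ParallelType parallel_type;

  bool isReduction() const { return is_reduction; }
  bool isDeviceDim() const { return parallel_type == ParallelType::DIDx; }
  void parallelize(ParallelType pt) { parallel_type = pt; }
};

// One loop domain per toy tensor.
using ToyLoopDomain = std::vector<ToyIterDomain>;

// Resets device-parallel reduction dimensions to Serial and leaves
// device-parallel iteration dimensions untouched.
void resetReductionDeviceDims(std::vector<ToyLoopDomain>& tvs) {
  for (ToyLoopDomain& loop : tvs) {
    for (ToyIterDomain& id : loop) {
      if (id.isReduction() && id.isDeviceDim()) {
        id.parallelize(ParallelType::Serial);
      }
    }
  }
}
```

Applied to the pair `[iDID{i1}, i2]` and `[rDID{i1}, i2]`, this sketch leaves the first tensor sharded and turns the second into `[r{i1}, i2]`, which matches the rewrite described in the quoted comment.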

@wujingyue wujingyue marked this pull request as ready for review November 21, 2024 00:57
@wujingyue
Collaborator Author

!test


TensorView* rfactor = reshape(
    in,
    {in->axis(0)->extent(), in->axis(1)->extent(), in->axis(2)->extent()});
Collaborator Author

@naoyam I believe there's a better way to write it. commitLeafToLogical?

Collaborator

I'm not entirely sure what is supposed to be computed here. Looks like only the outermost dimension of rfactor is reduced. What is assumed about the input? What is the output?

Collaborator Author

@wujingyue wujingyue Nov 21, 2024

I updated the test so it's no longer a concern.

FWIW, a more realistic example is the reduce-scatter after a sequence-parallel MLP forward. With DID loop split, it'll look like:

in: logical=[B, S, 4E], loop=[D, B, S, 4E/D]
w: logical=[E, 4E], loop=[D, E, 4E/D]
out=linear(in,w): logical=[B, S, E], loop=[D, B, S/D, E]

However, out->definition() does both linear and resharding, so it has to be decomposed before lowering.

I think we'll end up decomposing it as:

in: logical=[B, S, 4E], loop=[D, B, S, 4E/D]
in'=reshape(in, [D, B, S, 4E/D]): logical/loop=[D, B, S, 4E/D]
w: logical=[E, 4E], loop=[D, E, 4E/D]
w'=reshape(w, [D, E, 4E/D]): logical/loop=[D, E, 4E/D]
local_out=linear(in',w'): logical/loop=[D, B, S, E]
out=sum(local_out): logical=[B, S, E], loop=[D, B, S/D, E]
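
To check that the proposed decomposition computes the same values as the original linear, here is a small pure-Python sketch under stated assumptions: tensors are nested lists, the sizes are made up, and the helper names (`linear`, `chunk`, `sharded_linear_reduce_scatter`) are hypothetical rather than nvFuser APIs. Each device runs a local linear on its slice of the contraction dimension, the partial results are summed over D, and the sum is scattered along S.

```python
def linear(x, w):
    """x: [B][S][K], w: [E][K] -> [B][S][E] (x @ w^T, no bias)."""
    return [
        [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in batch]
        for batch in x
    ]


def chunk(seq, d, num_devices):
    """Return device d's contiguous slice of seq (length must divide evenly)."""
    size = len(seq) // num_devices
    return seq[d * size:(d + 1) * size]


def sharded_linear_reduce_scatter(x, w, num_devices):
    # in', w': every device keeps its slice of the contraction dimension.
    local_out = []
    for d in range(num_devices):
        x_d = [[chunk(row, d, num_devices) for row in batch] for batch in x]
        w_d = [chunk(w_row, d, num_devices) for w_row in w]
        local_out.append(linear(x_d, w_d))  # [B][S][E] partial sums on device d

    # out = sum(local_out): reduce over the device dimension D ...
    num_b, num_s, num_e = len(x), len(x[0]), len(w)
    summed = [
        [[sum(local_out[d][b][s][e] for d in range(num_devices))
          for e in range(num_e)]
         for s in range(num_s)]
        for b in range(num_b)
    ]
    # ... and scatter along S so that device d holds S/D rows.
    return [[chunk(batch, d, num_devices) for batch in summed]
            for d in range(num_devices)]


# Made-up sizes: D=2, B=1, S=2, E=2; the length-4 contraction
# dimension plays the role of 4E.
x = [[[1, 2, 3, 4], [5, 6, 7, 8]]]
w = [[1, 0, 1, 0], [0, 1, 0, 1]]
out = sharded_linear_reduce_scatter(x, w, num_devices=2)
reference = linear(x, w)
print(out, reference)
```

Concatenating the per-device results along S reproduces `reference`, which is the property the decomposition has to preserve.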

Collaborator

I wonder why reshape is used, but this is probably a topic for a future PR.

Collaborator Author

I wonder why reshape is used, but this is probably a topic for a future PR.

Because I can't find a better way at this moment... sum (which represents ReduceScatter) has to reduce a device dimension in the logical domain. Therefore, the linear has to produce a device dimension in its logical domain, hence the reshape.

Collaborator

This sounds like TensorView::rFactor.

Collaborator

@naoyam naoyam left a comment

hasDifferentShardings looks good to me. Not yet fully checked the tests.

csrc/multidevice/utils.cpp Outdated
Collaborator

@naoyam naoyam left a comment

LGTM

@wujingyue
Collaborator Author

!test

@wujingyue
Collaborator Author

!test

@wujingyue
Collaborator Author

@Priya2698 and @xwang233: can you teach me how to run host-latency benchmarks so I can make sure this PR doesn't cause big regressions like Lightning-AI/lightning-thunder#229?

@Priya2698
Collaborator

Priya2698 commented Nov 22, 2024

@Priya2698 and @xwang233: can you teach me how to run host-latency benchmarks so I can make sure this PR doesn't cause big regressions like Lightning-AI/lightning-thunder#229?

We have the following 3 benchmarks for host latency: test_many_pointwise_ops.py, test_adaptive_layernorm_host.py, and test_many_segments_host.py. They can be run using pytest similar to other benchmark files:

# Make sure you are not using NVFUSER_DISABLE=kernel_reuse to properly measure dynamic shape latency
pytest <benchmark_file>

You will get 3 types of measurement, specified by host_bench_mode=compile/steady/dynamic:
compile: Compilation through NVRTC (host time, first-time overhead). This measures the first-time overhead when executing a fusion definition.
steady: Cache-hit overhead. This ignores the first-time overhead and measures the host latency when executing the same input multiple times.
dynamic: Dynamic-shape overhead (not counting first-time overhead). This ignores the first-time overhead and measures the host latency when executing different input shapes for the same fusion definition.

You can see this comment for an example of the output you will see. The numbers for dynamic may have changed due to some recent fixes, so do run on main again for comparison.
Please let me know if you need any clarifications.

I will look into renaming benchmarks to have _host in all names or moving them to a host/ folder so they are easily identifiable as well.

@wujingyue
Collaborator Author

We have the following 3 benchmarks for host latency

Thanks -- those indeed caught a regression, e.g.,

---------------------------------------------------------------------------- benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16]": 2 tests -----------------------------------------------------------------------------
Name (time in ms)                                                                          Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0001_6627725)     1.0066 (1.0)      1.1206 (1.0)      1.0464 (1.0)      0.0499 (1.06)     1.0218 (1.0)      0.1073 (1.36)          3;0  955.6723 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16] (0002_c5c739d)     1.2108 (1.20)     1.3477 (1.20)     1.2577 (1.20)     0.0471 (1.0)      1.2406 (1.21)     0.0787 (1.0)           2;0  795.1148 (0.83)         10           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2]": 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                                                           Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0001_6627725)     239.9760 (1.0)      297.9760 (1.0)      259.3536 (1.0)      24.3398 (1.0)      245.6110 (1.0)      46.9880 (1.0)           3;0        3.8557 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2] (0002_c5c739d)     279.7310 (1.17)     366.2150 (1.23)     310.4558 (1.20)     34.5488 (1.42)     292.1150 (1.19)     69.8930 (1.49)          3;0        3.2211 (0.84)         10           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4]": 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                                                           Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0001_6627725)     340.4960 (1.0)      479.3810 (1.0)      380.8543 (1.0)      44.8786 (1.34)     361.7620 (1.0)      61.9180 (1.0)           1;0        2.6257 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4] (0002_c5c739d)     412.7540 (1.21)     493.0570 (1.03)     440.3725 (1.16)     33.5658 (1.0)      423.6700 (1.17)     69.2910 (1.12)          3;0        2.2708 (0.86)         10           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------------ benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8]": 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                                                           Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0001_6627725)     549.8740 (1.0)      649.2330 (1.0)      582.3793 (1.0)      40.9192 (1.0)      561.0210 (1.0)      78.9800 (1.0)           3;0        1.7171 (1.0)          10           1
test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8] (0002_c5c739d)     651.5180 (1.18)     761.0250 (1.17)     689.2533 (1.18)     46.5024 (1.14)     667.2220 (1.19)     99.8590 (1.26)          3;0        1.4508 (0.84)         10           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
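
To summarize the tables above, the following sketch computes the relative slowdown of run 0002 over run 0001 from the reported Mean columns. The values are copied from the tables; the 5% threshold and the variable names are assumptions chosen for illustration, not project policy.

```python
# Mean host latency copied from the tables above, keyed by num_iters:
# (run 0001, run 0002). Units differ per table (ms for 16, us otherwise)
# but cancel out in the ratio.
means = {
    2: (259.3536, 310.4558),
    4: (380.8543, 440.3725),
    8: (582.3793, 689.2533),
    16: (1.0464, 1.2577),
}

THRESHOLD = 1.05  # hypothetical: flag anything more than 5% slower


def slowdown(first, second):
    """Ratio of the second run's mean latency to the first run's."""
    return second / first


slowdowns = {n: slowdown(a, b) for n, (a, b) in means.items()}
regressed = sorted(n for n, s in slowdowns.items() if s > THRESHOLD)

for n in sorted(slowdowns):
    print(f"num_iters={n}: {slowdowns[n]:.2f}x")
print("flagged:", regressed)
```

The computed ratios agree with the parenthesized factors in the Mean columns, which range from about 1.16 to 1.20 across the four `num_iters` settings.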

wujingyue added a commit that referenced this pull request Nov 23, 2024
This is to help track down the host latency regression found in #3421.
wujingyue added a commit that referenced this pull request Nov 23, 2024
This is to help track down the host latency regression found in #3421.
wujingyue added a commit that referenced this pull request Nov 24, 2024
This is to help track down the host latency regression found in #3421.
@jacobhinkle
Collaborator

jacobhinkle commented Nov 25, 2024

@wujingyue I was initially surprised to see this impacting dynamic host latency, since that path should not trigger pre-seg passes and I thought this PR only touched the pre-seg pass. The pre-seg passes are only run when we create a new FusionKernelRuntime, which does not happen if we are able to re-use a runtime for a new set of input shapes as in the dynamic benchmark.

However, I think the perf impact is because this modifies scheduler_utils::isResharding. I expected that to also only be used in compile time checks, but there is one use here, in scheduler_utils::getReductionTvs:

!isResharding(tv->definition())) {
And this function is actually used in runtime checks and heuristics for the reduction and normalization schedulers.
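
The distinction drawn here can be sketched with a toy model. This is a deliberately simplified, hypothetical cache and not nvFuser's FusionKernelRuntime: it assumes that one-time passes run only when a runtime is first created, and that a heuristic check runs once per previously unseen input shape. Under those assumptions, any cost added to the heuristic path shows up in a dynamic-shape benchmark even though no new runtime is created.

```python
class ToyRuntimeCache:
    """Hypothetical model of runtime reuse; all names are illustrative."""

    def __init__(self):
        self.seen_shapes = {}      # fusion_id -> set of input shapes handled so far
        self.preseg_pass_runs = 0  # one-time work when a runtime is created
        self.heuristic_checks = 0  # per-shape work when reusing a runtime

    def execute(self, fusion_id, input_shape):
        if fusion_id not in self.seen_shapes:
            # First execution: create the runtime and run one-time passes.
            self.preseg_pass_runs += 1
            self.seen_shapes[fusion_id] = set()
        if input_shape not in self.seen_shapes[fusion_id]:
            # New shape for an existing runtime: only heuristics are re-checked.
            self.heuristic_checks += 1
            self.seen_shapes[fusion_id].add(input_shape)


cache = ToyRuntimeCache()
for shape in [(2, 3), (4, 5), (8, 9), (2, 3)]:
    cache.execute("pointwise_fusion", shape)
print(cache.preseg_pass_runs, cache.heuristic_checks)
```

In this model the one-time counter stays at 1 while the heuristic counter grows with each new shape, which mirrors why a function called from runtime checks can affect the dynamic benchmark.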

@jacobhinkle
Collaborator

BTW there is currently both nvfuser::scheduler_utils::isResharding and nvfuser::isResharding.

scheduler_utils::isResharding has this caller graph showing it's only used at compile time:
[image: caller graph of scheduler_utils::isResharding]

nvfuser::isResharding, which is what this PR is modifying, has this caller graph which includes runtime/heuristic stuff:
[image: caller graph of nvfuser::isResharding]

@wujingyue
Collaborator Author

I should have synced the branch past @jacobhinkle's 3266b9d. The regression no longer happens.

@wujingyue
Collaborator Author

!test

@wujingyue wujingyue merged commit c4a0335 into main Nov 25, 2024
34 of 35 checks passed
@wujingyue wujingyue deleted the wjy/loop branch November 25, 2024 21:16
@wujingyue wujingyue added the enhancement New feature or request label Nov 28, 2024
4 participants