DID parallelization of loop domain #2563
Comments
Temporarily add support for a sharded forward scaled dot product attention. Currently, we only support DID parallelization on the logical domain, which requires us to split and parallelize an axis at the logical level (see #2563). This is a hack until we support DID parallelization on the loop domain, after which this PR can be reverted.

Restrictions:
1. The q, k, v inputs are manually sharded _before_ the SDPAFwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks are called when the op is created.
2. Only the head dimension is sharded, and every input and output either has a sharded head dimension or is unsharded.
3. The DID axis is the outermost axis. This is because during evaluation, if we see 5 dimensions, the first is assumed to be the DID axis; it is squeezed from the inputs and unsqueezed onto the outputs.
Adds temporary support for sharded backward scaled dot product attention until #2563 is completed. Similar to #2565, with similar restrictions:
1. All necessary sharded inputs are manually sharded before the SDPABwdOp is created. We cannot rely on sharding propagation or on sharding after the Fusion is created, because the dimension checks are called when the op is created.
2. Only the head dimension is sharded, and every input and output either has a sharded head dimension or is unsharded.
3. The DID axis is the outermost axis. This is because during evaluation, if we see 5 dimensions, the first is assumed to be the DID axis; it is squeezed from the inputs and unsqueezed onto the outputs.
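The third restriction can be pictured with a small shape-only sketch. This is not nvFuser code: the helper names (`squeeze_did`, `unsqueeze_did`, `local_sdpa_output_shape`) are hypothetical, and it assumes an unsharded layout of (batch, heads, seq, embed) and a local DID extent of 1 on each device.

```python
# Hypothetical illustration of the "5 dimensions means a leading DID axis" rule.
# Shapes are plain tuples; no tensors or nvFuser APIs are involved.

def squeeze_did(shape):
    """Drop the leading DID axis if the shape is 5-D.

    Assumption: a 5-D shape is (DID, batch, heads_per_device, seq, embed)
    and the per-device extent of the DID axis is 1.
    Returns the 4-D shape and whether a DID axis was removed.
    """
    if len(shape) == 5:
        if shape[0] != 1:
            raise ValueError("expected a local DID extent of 1")
        return tuple(shape[1:]), True
    return tuple(shape), False


def unsqueeze_did(shape, was_sharded):
    """Put the DID axis back in front if it was squeezed from the inputs."""
    return (1, *shape) if was_sharded else tuple(shape)


def local_sdpa_output_shape(q_shape, v_shape):
    """Attention output shape: q's leading dims plus v's last dim."""
    q, q_sharded = squeeze_did(q_shape)
    v, v_sharded = squeeze_did(v_shape)
    if q_sharded != v_sharded:
        raise ValueError("inputs must be all sharded or all unsharded")
    out = (*q[:-1], v[-1])
    return unsqueeze_did(out, q_sharded)


# 8 heads sharded over 4 devices: each device sees 2 heads plus the DID axis.
sharded = local_sdpa_output_shape((1, 2, 2, 16, 8), (1, 2, 2, 16, 8))
unsharded = local_sdpa_output_shape((2, 8, 16, 8), (2, 8, 16, 8))
```

The sketch also shows why the rule is fragile: the rank of the tensor is the only signal that sharding is present, which is one reason the PRs describe it as a hack.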
As we discussed in the meeting, here's a breakdown of the task:
These subtasks are roughly in execution order. #3282 pretty much blocks everything else, so it's currently my top priority. After that's done, some other tasks (e.g. 2-4) can be parallelized as long as we limit the scope of testing. cc @naoyam, @cowanmeg and @samnordmann to take a look
It's probably necessary to update the allocation logic of gmem tensors. Since we always allocate the whole tensor for a single-GPU execution, the logical domain is used to determine the allocation size; that would need to change to the loop domain. This is just one example of loop vs. logical. I imagine there are other cases as well.
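To make the size difference concrete, here is a minimal sketch that compares the two element counts. It is an illustration under stated assumptions, not the nvFuser allocation logic: the function names are hypothetical, and the loop domain is assumed to come from an outer split of one logical axis by the number of devices, with the outer (DID) part not allocated locally.

```python
from math import prod


def logical_alloc_numel(logical_shape):
    """Element count when the whole logical tensor is allocated."""
    return prod(logical_shape)


def loop_alloc_numel(logical_shape, sharded_axis, num_devices):
    """Per-device element count when one axis is split by num_devices.

    Assumption: the outer part of the split is DID-parallelized, so only
    the inner extent (extent // num_devices) is allocated on each device.
    """
    extent = logical_shape[sharded_axis]
    if extent % num_devices != 0:
        raise ValueError("sharded extent must be divisible by num_devices")
    local_shape = list(logical_shape)
    local_shape[sharded_axis] = extent // num_devices
    return prod(local_shape)


shape = (2, 8, 16, 4)  # e.g. (batch, heads, seq, embed)
full = logical_alloc_numel(shape)
per_device = loop_alloc_numel(shape, sharded_axis=1, num_devices=4)
```

With these numbers the logical domain gives 1024 elements while the per-device count is 256, so sizing the buffer from the logical domain would over-allocate by the device count.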
Yes, I actually experimented a bit with that in this very old and out-of-date PR, #2181
I believe it has been fixed: see Fuser/csrc/runtime/allocations.cpp, lines 737 to 746 in 85c22a2.
Thanks for creating this issue so we can keep track!
@naoyam agreed!
For #2563

Host latency benchmarks are neutral:

```
$ pytest-benchmark compare 0003 0004 --group-by=name

benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='compile']": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0003_96a8efb      170.7932 (1.0)   175.0024 (1.0)   173.0403 (1.0)   1.2632 (1.0)     173.0455 (1.0)   1.7885 (1.0)     4;0       5.7790 (1.0)     10      1
0004_24e1b90      170.9341 (1.00)  175.2137 (1.00)  173.4435 (1.00)  1.4717 (1.17)    173.6018 (1.00)  1.9277 (1.08)    3;0       5.7656 (1.00)    10      1

benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='dynamic']": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      114.1470 (1.0)   157.7900 (1.27)  124.2067 (1.05)  17.6455 (5.83)   115.8945 (1.0)   3.3760 (1.0)     2;2       8.0511 (0.96)    10      1
0004_24e1b90      114.8980 (1.01)  123.9660 (1.0)   118.6434 (1.0)   3.0273 (1.0)     118.6260 (1.02)  5.5200 (1.64)    3;0       8.4286 (1.0)     10      1

benchmark "test_adaptive_layernorm_fwd_benchmark[host_bench_mode='steady']": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      39.5550 (1.0)    50.2060 (1.13)   42.4156 (1.02)   3.4449 (2.72)    40.8575 (1.0)    4.3180 (4.40)    2;0       23.5762 (0.98)   10      1
0004_24e1b90      39.9650 (1.01)   44.5350 (1.0)    41.3924 (1.0)    1.2683 (1.0)     41.1580 (1.01)   0.9820 (1.0)     2;1       24.1590 (1.0)    10      1

benchmark "test_many_segment_benchmark[host_bench_mode='compile']": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0003_96a8efb      208.3781 (1.0)   213.6772 (1.0)   211.2823 (1.0)   1.3894 (1.0)     211.2441 (1.0)   0.9576 (1.0)     2;2       4.7330 (1.0)     10      1
0004_24e1b90      212.2643 (1.02)  224.4728 (1.05)  216.4410 (1.02)  3.3363 (2.40)    216.1582 (1.02)  2.8121 (2.94)    2;1       4.6202 (0.98)    10      1

benchmark "test_many_segment_benchmark[host_bench_mode='dynamic']": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      500.2300 (1.0)   626.6390 (1.11)  548.1104 (1.03)  54.3243 (2.19)   510.5340 (1.0)   97.5850 (1.95)   3;0       1.8244 (0.97)    10      1
0004_24e1b90      500.3700 (1.00)  563.3790 (1.0)   530.3696 (1.0)   24.7773 (1.0)    530.2765 (1.04)  50.0760 (1.0)    5;0       1.8855 (1.0)     10      1

benchmark "test_many_segment_benchmark[host_bench_mode='steady']": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      174.8920 (1.0)   193.9480 (1.00)  179.0197 (1.0)   5.6300 (1.08)    176.5750 (1.0)   3.8570 (1.0)     1;1       5.5860 (1.0)     10      1
0004_24e1b90      175.0420 (1.00)  193.0060 (1.0)   181.0727 (1.01)  5.2312 (1.0)     180.9635 (1.02)  5.4300 (1.41)    2;1       5.5226 (0.99)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=16]": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0003_96a8efb      274.7113 (1.0)   278.1071 (1.0)   276.5300 (1.0)   1.1733 (1.0)     276.5168 (1.0)   2.3361 (1.73)    5;0       3.6162 (1.0)     10      1
0004_24e1b90      274.8315 (1.00)  280.0985 (1.01)  277.2161 (1.00)  1.3838 (1.18)    276.9688 (1.00)  1.3489 (1.0)     2;1       3.6073 (1.00)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=2]": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0004_24e1b90      68.5291 (1.0)    69.5794 (1.0)    69.0360 (1.0)    0.4227 (1.0)     69.0748 (1.0)    0.7587 (1.0)     3;0       14.4852 (1.0)    10      1
0003_96a8efb      70.4157 (1.03)   72.0312 (1.04)   71.4535 (1.04)   0.5711 (1.35)    71.5552 (1.04)   0.9860 (1.30)    2;0       13.9951 (0.97)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=4]": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0004_24e1b90      91.7454 (1.0)    93.9708 (1.0)    93.2835 (1.0)    0.7899 (1.0)     93.5256 (1.0)    0.8339 (1.0)     2;1       10.7200 (1.0)    10      1
0003_96a8efb      94.9294 (1.03)   97.1430 (1.03)   95.9882 (1.03)   0.8550 (1.08)    95.9445 (1.03)   1.5375 (1.84)    4;0       10.4179 (0.97)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='compile'-num_iters=8]": 2 tests
Run (time in ms)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS              Rounds  Iterations
0004_24e1b90      147.1260 (1.0)   148.6477 (1.0)   147.7856 (1.0)   0.5382 (1.0)     147.7964 (1.0)   0.9269 (1.0)     3;0       6.7666 (1.0)     10      1
0003_96a8efb      148.2082 (1.01)  151.0851 (1.02)  149.7864 (1.01)  0.8852 (1.64)    149.8612 (1.01)  0.9669 (1.04)    3;0       6.6762 (0.99)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=16]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      73.6900 (1.0)    114.9290 (1.39)  86.4882 (1.13)   13.3314 (4.96)   81.0935 (1.07)   18.4040 (9.83)   2;0       11.5623 (0.89)   10      1
0004_24e1b90      73.8800 (1.00)   82.7570 (1.0)    76.5963 (1.0)    2.6876 (1.0)     75.8490 (1.0)    1.8720 (1.0)     3;1       13.0555 (1.0)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=2]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0004_24e1b90      63.3200 (1.0)    72.0470 (1.0)    68.0081 (1.0)    2.8500 (1.0)     67.4285 (1.0)    4.7200 (1.0)     4;0       14.7041 (1.0)    10      1
0003_96a8efb      65.8250 (1.04)   92.5460 (1.28)   74.4402 (1.09)   10.3170 (3.62)   69.5170 (1.03)   18.0350 (3.82)   3;0       13.4336 (0.91)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=4]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0004_24e1b90      64.0220 (1.0)    78.1180 (1.0)    68.6502 (1.0)    3.6925 (1.0)     68.1590 (1.0)    0.9820 (1.0)     2;3       14.5666 (1.0)    10      1
0003_96a8efb      67.1480 (1.05)   96.8140 (1.24)   77.4761 (1.13)   10.4931 (2.84)   73.0485 (1.07)   18.8960 (19.24)  3;0       12.9072 (0.89)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='dynamic'-num_iters=8]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0004_24e1b90      68.3790 (1.0)    76.8860 (1.0)    71.7751 (1.0)    2.7013 (1.0)     71.4705 (1.0)    2.7150 (1.0)     4;1       13.9324 (1.0)    10      1
0003_96a8efb      70.9040 (1.04)   98.3670 (1.28)   80.6828 (1.12)   11.6885 (4.33)   75.1170 (1.05)   23.9160 (8.81)   3;0       12.3942 (0.89)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=16]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      23.9860 (1.0)    33.7140 (1.11)   26.1107 (1.01)   2.8895 (1.61)    24.8575 (1.0)    1.8430 (2.42)    1;1       38.2985 (0.99)   10      1
0004_24e1b90      24.7370 (1.03)   30.4480 (1.0)    25.8039 (1.0)    1.7968 (1.0)     25.0880 (1.01)   0.7610 (1.0)     1;2       38.7538 (1.0)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=2]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      23.9560 (1.0)    32.4720 (1.09)   26.1016 (1.01)   2.9284 (1.72)    24.9875 (1.0)    1.4230 (1.48)    2;2       38.3118 (0.99)   10      1
0004_24e1b90      24.4160 (1.02)   29.8470 (1.0)    25.7449 (1.0)    1.7051 (1.0)     25.1225 (1.01)   0.9620 (1.0)     2;2       38.8426 (1.0)    10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=4]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      23.3650 (1.0)    30.4380 (1.08)   25.1518 (1.0)    1.9443 (1.76)    24.6120 (1.0)    0.7120 (1.13)    1;1       39.7586 (1.0)    10      1
0004_24e1b90      24.2360 (1.04)   28.1530 (1.0)    25.2177 (1.00)   1.1044 (1.0)     24.8370 (1.01)   0.6320 (1.0)     1;1       39.6547 (1.00)   10      1

benchmark "test_pointwise_ops_benchmark[host_bench_mode='steady'-num_iters=8]": 2 tests
Run (time in us)  Min              Max              Mean             StdDev           Median           IQR              Outliers  OPS (Kops/s)     Rounds  Iterations
0003_96a8efb      24.1860 (1.0)    33.0630 (1.17)   25.4774 (1.0)    2.7047 (2.00)    24.5365 (1.0)    1.0620 (1.0)     1;1       39.2505 (1.0)    10      1
0004_24e1b90      24.2860 (1.00)   28.3440 (1.0)    25.6967 (1.01)   1.3499 (1.0)     25.1020 (1.02)   2.0040 (1.89)    3;0       38.9155 (0.99)   10      1
```
…transforms (#3458)

This is a spin-off from #3444. The current code assumes that the logical-to-allocation transform has to be a permutation. This assumption will no longer hold with #2563, so this PR extends eraseInputDistinctRootDomains to support more general transforms. Non-permutation transforms can also occur on a single GPU, although less commonly. The tests added in this PR are single-GPU because #3444 hasn't landed; #3444 will add some multi-GPU tests.
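A toy sketch of why the permutation assumption stops holding once an axis is split. Domains are modeled as lists of axis names; `is_permutation` and `outer_split` are hypothetical helpers for illustration and do not correspond to nvFuser functions.

```python
def is_permutation(logical, allocation):
    """True when allocation only reorders the logical axes."""
    return sorted(logical) == sorted(allocation)


def outer_split(domain, axis, outer_name):
    """Replace `axis` with an outer axis and the remaining inner axis."""
    i = domain.index(axis)
    return domain[:i] + [outer_name, f"{axis}/{outer_name}"] + domain[i + 1:]


logical = ["b", "h", "s", "e"]

# A reordering of the logical axes is still a permutation.
reordered = ["b", "s", "h", "e"]

# Splitting h for device parallelization introduces axes that do not
# exist in the logical domain, so the result is no longer a permutation.
split = outer_split(logical, "h", "DIDx")
```

Here `is_permutation(logical, reordered)` holds, while `is_permutation(logical, split)` does not, because `split` contains `DIDx` and `h/DIDx` in place of `h`.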
Currently, DID parallelization is only supported on the logical domain, which requires us to manually add extra axes to the logical domain when parallelizing Fusions across multiple devices.
As a result, some PRs have introduced temporary hacks to work around logical DID axes, and other code will need to be modified. Both are tracked here.
FYR, an internal design doc from @cowanmeg