From 20812afa59e99d6e6c37eb91501f5f0958ac52d0 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Tue, 29 Oct 2024 12:54:32 -0700 Subject: [PATCH 01/19] Start implementation --- tests/cpp/test_multidevice_overlap.cpp | 80 ++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index e371a53dd4d..a32f5c9575e 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -937,4 +937,84 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) { } } +class RingAllgatherOverlapTest : public MultiDeviceTest { + protected: + int64_t number_of_steps_per_ring_, number_of_rings_; + at::Tensor src_buffer_, dst_buffer_; + at::Tensor ta_reshaped_, tc_reshaped_; + + void SetUp() { + AllgatherOverlapTest::SetUp(); + if (!communicator_->is_available()) { + return; + } + + ASSERT_EQ(params.S % num_devices_, 0); + number_of_steps_per_ring_ = num_devices_; + number_of_rings_ = params.S / num_devices_; + + ta_reshaped_ = at::reshape( + ta_, + {number_of_steps_per_ring_, + number_of_rings_, + params.M / params.S, + params.K / num_devices_}); + tc_reshaped_ = + tc_.reshape({number_of_rings_, params.M / params.S, params.N}); + + std::vector buffer_sizes = { + number_of_steps_per_ring_, + number_of_rings_, + params.M / params.S, + params.N}; + src_buffer_ = at::empty(buffer_sizes, gpu_options_); + dst_buffer_ = at::empty(buffer_sizes, gpu_options_); + } +}; + +TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) { + std::vector streams = + createStreams(params.number_of_streams, my_device_index_); + + for ([[maybe_unused]] const auto& _ : + c10::irange(params.number_of_iterations)) { + + initializeIO(); + + for (auto i : c10::irange(number_of_rings_)) { + for (auto j : c10::irange(number_of_steps_per_ring_)) { + int64_t stream_index = (i + j) % streams.size(); + setCurrentCUDAStream(streams.at(stream_index)); + + // define the sliced tensors + auto slice_index = + (my_device_index_ + j + 1) % number_of_steps_per_ring_; + auto ta_j = ta_reshaped_.select(0, slice_index).select(0, i); + auto src_buffer_j = src_buffer_.select(0, j).select(0, i); + auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); + + // define the peer ranks + auto send_rank = slice_index; + auto recv_rank = + (number_of_steps_per_ring_ + my_device_index_ - (j + 1)) % + number_of_steps_per_ring_; + + // local compute + torch::matmul_out(src_buffer_j, ta_j, tb_); + // communication + std::vector src = {src_buffer_j}; + std::vector dst = {dst_buffer_j}; + + world_communicator_->startCoalescing(); + // "tags" are not supported by nccl, so set it to 0 + world_communicator_->send(src, send_rank, 0); + world_communicator_->recv(dst, recv_rank, 0); + world_communicator_->endCoalescing()->wait(); + } + } + synchronizeStreams(streams); + torch::sum_out(tc_reshaped_, dst_buffer_, 0); + } +} + } // namespace nvfuser From cd74d51be94b7e42528fe8a64bfe78a5195a063d Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Wed, 30 Oct 2024 14:57:31 -0700 Subject: [PATCH 02/19] update --- tests/cpp/test_multidevice_overlap.cpp | 106 +++++++++++++++++++------ 1 file changed, 82 insertions(+), 24 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index a32f5c9575e..94aed0a3c62 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -939,36 +939,92 @@ TEST_F(AllgatherOverlapTest, 
AllgatherBasedPipeliningHostIrImplementation) { class RingAllgatherOverlapTest : public MultiDeviceTest { protected: + OverlapTestParams params; + + int64_t num_devices_; + int64_t my_device_index_; int64_t number_of_steps_per_ring_, number_of_rings_; - at::Tensor src_buffer_, dst_buffer_; - at::Tensor ta_reshaped_, tc_reshaped_; + std::vector all_devices_; + at::Tensor ta_unsharded_, tb_unsharded_, tc_unsharded_; + at::Tensor tb_; + // stores the backend + c10d::Backend* world_communicator_; + + // Define I/O and intermediate Tensor shapes + std::vector ta_unsharded_sizes; + std::vector tb_unsharded_sizes; + std::vector tc_unsharded_sizes; + std::vector tb_sizes; void SetUp() { - AllgatherOverlapTest::SetUp(); - if (!communicator_->is_available()) { - return; - } + MultiDeviceTest::SetUp(); + + num_devices_ = communicator_->size(); + my_device_index_ = communicator_->deviceId(); ASSERT_EQ(params.S % num_devices_, 0); number_of_steps_per_ring_ = num_devices_; number_of_rings_ = params.S / num_devices_; - ta_reshaped_ = at::reshape( - ta_, - {number_of_steps_per_ring_, - number_of_rings_, - params.M / params.S, - params.K / num_devices_}); - tc_reshaped_ = - tc_.reshape({number_of_rings_, params.M / params.S, params.N}); + // Setup the world communicators + std::vector devices(num_devices_); + std::iota(devices.begin(), devices.end(), 0); + all_devices_ = std::move(devices); + world_communicator_ = + communicator_->getBackendForTeam(all_devices_, params.backend_type); - std::vector buffer_sizes = { - number_of_steps_per_ring_, - number_of_rings_, - params.M / params.S, - params.N}; - src_buffer_ = at::empty(buffer_sizes, gpu_options_); - dst_buffer_ = at::empty(buffer_sizes, gpu_options_); + // Debug print + if (communicator_->deviceId() == 0 && debug_print) { + debug() << params << std::endl; + } + + // A(M, K) + // B(K, sharded(N)) + // C(M, N) + ta_unsharded_sizes = std::vector{params.M, params.K}; + tb_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tb_sizes = std::vector{number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tc_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.M, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + + // Set up input tensors. We create the full unsharded tensors and define the + // actual input as the shard corresponding to the current device. Having the + // full unsharded input on each rank makes it possible to compute the + // expected result locally, hence, this way of doing is convenient for + // validating data correctness. 
+ auto cpu_options = at::TensorOptions().dtype(at::kFloat); + at::TensorOptions gpu_options = cpu_options.device(communicator_->device()); + + ta_unsharded_ = at::empty(ta_unsharded_sizes, gpu_options); + tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); + tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); + tb_ = at::empty(tb_sizes, gpu_options); + + // Debug print + if (communicator_->deviceId() == 0 && debug_print) { + debug() << "ta_unsharded_sizes()=" << ta_unsharded_.sizes() << std::endl + << "tb_unsharded_sizes()=" << tb_unsharded_.sizes() << std::endl + << "tc_unsharded_sizes()=" << tc_unsharded_.sizes() << std::endl + << "tb_.sizes()=" << tb_.sizes() << std::endl; + } + } + + // Each rank calls uniform_ and gets the same values for ta_ and tb_ because + // the random seed is initialized the same Therefore, we do not need to have + // one rank generate ta_ and tb_ and broadcast it to the rest of the ranks + void initializeIO() { + ta_unsharded_.uniform_(); + tb_unsharded_.uniform_(); + tb_.copy_(tb_unsharded_.select(1, my_device_index_)); + } + + void validate() { + // compute the expected output for data correctness validation + auto tc_unsharded_expected_ = + torch::matmul(ta_unsharded_.cpu(), tb_unsharded_); + EXPECT_TRUE( + tc_unsharded_.cpu().allclose(tc_unsharded_expected_, 1e-1, 1e-1)) + << "Unexpected results, obtained: " << tc_unsharded_ + << "expected: " << tc_unsharded_expected_; } }; @@ -976,6 +1032,9 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) std::vector streams = createStreams(params.number_of_streams, my_device_index_); + auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; + auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { @@ -988,8 +1047,8 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) // define the sliced tensors auto slice_index = - (my_device_index_ + j + 1) % number_of_steps_per_ring_; - auto ta_j = ta_reshaped_.select(0, slice_index).select(0, i); + (my_device_index_ + j) % number_of_steps_per_ring_; + auto tb_j = tb_.select(0, slice_index).select(0, i); auto src_buffer_j = src_buffer_.select(0, j).select(0, i); auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); @@ -1013,7 +1072,6 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) } } synchronizeStreams(streams); - torch::sum_out(tc_reshaped_, dst_buffer_, 0); } } From 658f0089d4a8ed70d331c1540583003e744059d0 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Wed, 30 Oct 2024 15:27:37 -0700 Subject: [PATCH 03/19] update --- tests/cpp/test_multidevice_overlap.cpp | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 94aed0a3c62..112b629586a 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1041,34 +1041,26 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) initializeIO(); for (auto i : c10::irange(number_of_rings_)) { + c10::intrusive_ptr recv_req{NULL}; for (auto j : c10::irange(number_of_steps_per_ring_)) { int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); - // define the sliced tensors auto slice_index = (my_device_index_ + j) % number_of_steps_per_ring_; auto tb_j = 
tb_.select(0, slice_index).select(0, i); - auto src_buffer_j = src_buffer_.select(0, j).select(0, i); - auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); - // define the peer ranks - auto send_rank = slice_index; - auto recv_rank = - (number_of_steps_per_ring_ + my_device_index_ - (j + 1)) % - number_of_steps_per_ring_; + //auto src_buffer_j = src_buffer_.select(0, j).select(0, i); + //auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); - // local compute - torch::matmul_out(src_buffer_j, ta_j, tb_); - // communication - std::vector src = {src_buffer_j}; + // recv next index std::vector dst = {dst_buffer_j}; + auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); - world_communicator_->startCoalescing(); - // "tags" are not supported by nccl, so set it to 0 - world_communicator_->send(src, send_rank, 0); - world_communicator_->recv(dst, recv_rank, 0); - world_communicator_->endCoalescing()->wait(); + // send & matmul current index + std::vector src = {src_buffer_j}; + auto send_req = world_communicator_->send(src, send_rank, 0); + torch::matmul_out(src_buffer_j, ta_, tb_j); } } synchronizeStreams(streams); From 28702bfdf1980793b0ff5c925b006c845d2527a0 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Thu, 31 Oct 2024 09:13:13 -0700 Subject: [PATCH 04/19] update --- tests/cpp/test_multidevice_overlap.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 112b629586a..2decb7737b6 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -947,6 +947,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { std::vector all_devices_; at::Tensor ta_unsharded_, tb_unsharded_, tc_unsharded_; at::Tensor tb_; + at::Tensor ring_buffer_; // stores the backend c10d::Backend* world_communicator_; @@ -955,6 +956,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { std::vector tb_unsharded_sizes; std::vector tc_unsharded_sizes; std::vector tb_sizes; + std::vector buffer_sizes; void SetUp() { MultiDeviceTest::SetUp(); @@ -985,6 +987,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { tb_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; tb_sizes = std::vector{number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; tc_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.M, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + buffer_sizes = {params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; // same as tb_sizes but without the outermost number_of_rings_ dimension TODO: cpp-esque way of initializing this based on tb_sizes? // Set up input tensors. We create the full unsharded tensors and define the // actual input as the shard corresponding to the current device. 
Having the @@ -997,7 +1000,8 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { ta_unsharded_ = at::empty(ta_unsharded_sizes, gpu_options); tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); - tb_ = at::empty(tb_sizes, gpu_options); + tb_ = at::empty(tb_sizes, gpu_options); + ring_buffer_ = at::empty(buffer_sizes, gpu_options_); // Debug print if (communicator_->deviceId() == 0 && debug_print) { @@ -1048,19 +1052,24 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto slice_index = (my_device_index_ + j) % number_of_steps_per_ring_; - auto tb_j = tb_.select(0, slice_index).select(0, i); + auto tb_j = tb_.select(0, i); + auto tc_j = tc_.select(0, slice_index).select(0, i); //auto src_buffer_j = src_buffer_.select(0, j).select(0, i); //auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); // recv next index - std::vector dst = {dst_buffer_j}; + std::vector dst = {tb_j}; auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); + if (recv_req) { + recv_req->wait(); + } + // send & matmul current index std::vector src = {src_buffer_j}; auto send_req = world_communicator_->send(src, send_rank, 0); - torch::matmul_out(src_buffer_j, ta_, tb_j); + torch::matmul_out(tc_j, ta_, tb_j); } } synchronizeStreams(streams); From 32c3296d009f16cdf7ed3b6277cc0a18729868a1 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Thu, 31 Oct 2024 10:39:29 -0700 Subject: [PATCH 05/19] update --- tests/cpp/test_multidevice_overlap.cpp | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 2decb7737b6..c01da3b1366 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -987,7 +987,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { tb_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; tb_sizes = std::vector{number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; tc_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.M, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - buffer_sizes = {params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; // same as tb_sizes but without the outermost number_of_rings_ dimension TODO: cpp-esque way of initializing this based on tb_sizes? + buffer_sizes = tb_sizes; // Set up input tensors. We create the full unsharded tensors and define the // actual input as the shard corresponding to the current device. 
Having the @@ -1001,7 +1001,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); tb_ = at::empty(tb_sizes, gpu_options); - ring_buffer_ = at::empty(buffer_sizes, gpu_options_); + ring_buffer_ = at::empty(buffer_sizes, gpu_options); // Debug print if (communicator_->deviceId() == 0 && debug_print) { @@ -1045,7 +1045,7 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) initializeIO(); for (auto i : c10::irange(number_of_rings_)) { - c10::intrusive_ptr recv_req{NULL}; + c10::intrusive_ptr recv_req = nullptr; for (auto j : c10::irange(number_of_steps_per_ring_)) { int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); @@ -1053,26 +1053,28 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto slice_index = (my_device_index_ + j) % number_of_steps_per_ring_; auto tb_j = tb_.select(0, i); - auto tc_j = tc_.select(0, slice_index).select(0, i); - - //auto src_buffer_j = src_buffer_.select(0, j).select(0, i); - //auto dst_buffer_j = dst_buffer_.select(0, j).select(0, i); + auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); // recv next index - std::vector dst = {tb_j}; + std::vector dst = {ring_buffer_}; auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); if (recv_req) { recv_req->wait(); } + // if it's the first iteration, ring_buffer_ is empty and we haven't yet taken care of our tb_j + auto sendbuf_ = recv_req ? ring_buffer_ : tb_j; + // send & matmul current index - std::vector src = {src_buffer_j}; - auto send_req = world_communicator_->send(src, send_rank, 0); - torch::matmul_out(tc_j, ta_, tb_j); + std::vector src = {ring_buffer_}; + world_communicator_->send(src, send_rank, 0); + torch::matmul_out(tc_j, ta_unsharded_, sendbuf_); + recv_req = next_recv_req; } } synchronizeStreams(streams); + validate(); } } From d5e3803ab2ef9fb7d749d980f54842d907bbfedc Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Thu, 31 Oct 2024 11:12:07 -0700 Subject: [PATCH 06/19] update --- tests/cpp/test_multidevice_overlap.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index c01da3b1366..3ec6e2448cb 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1018,7 +1018,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { void initializeIO() { ta_unsharded_.uniform_(); tb_unsharded_.uniform_(); - tb_.copy_(tb_unsharded_.select(1, my_device_index_)); + tb_.copy_(tb_unsharded_.select(0, my_device_index_)); } void validate() { @@ -1054,20 +1054,21 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) (my_device_index_ + j) % number_of_steps_per_ring_; auto tb_j = tb_.select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); + auto ring_buffer_j = ring_buffer_.select(0, i); // recv next index - std::vector dst = {ring_buffer_}; + std::vector dst = {ring_buffer_j}; auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); if (recv_req) { recv_req->wait(); } - // if it's the first iteration, ring_buffer_ is empty and we haven't yet taken care of our tb_j - auto sendbuf_ = recv_req ? ring_buffer_ : tb_j; + // if it's the first iteration, ring_buffer_j is empty and we haven't yet taken care of our tb_j + auto sendbuf_ = recv_req ? 
ring_buffer_j : tb_j; // send & matmul current index - std::vector src = {ring_buffer_}; + std::vector src = {ring_buffer_j}; world_communicator_->send(src, send_rank, 0); torch::matmul_out(tc_j, ta_unsharded_, sendbuf_); recv_req = next_recv_req; From 592afde7305b5b0bb04f50bc2c2488fc8c79260e Mon Sep 17 00:00:00 2001 From: nsarka Date: Thu, 31 Oct 2024 17:50:12 -0400 Subject: [PATCH 07/19] update --- tests/cpp/test_multidevice_overlap.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 3ec6e2448cb..eaffd5527c9 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -45,16 +45,16 @@ struct OverlapTestParams { int64_t M = std::pow(2, 6); int64_t K = std::pow(2, 5); int64_t N = std::pow(2, 4); - int64_t S = std::pow(2, 3); + int64_t S = std::pow(2, 2); // network backend type CommunicatorBackend backend_type = CommunicatorBackend::kNccl; // Overlap optimization parameters // fill input with new random values and repeat the operation - int64_t number_of_iterations = 4; + int64_t number_of_iterations = 1; // Change CUDA stream at each iteration in a Round-Robin fashion - int64_t number_of_streams = 3; + int64_t number_of_streams = 128; }; std::ostream& operator<<(std::ostream& out, const OverlapTestParams& params) { @@ -1025,10 +1025,11 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { // compute the expected output for data correctness validation auto tc_unsharded_expected_ = torch::matmul(ta_unsharded_.cpu(), tb_unsharded_); + if (my_device_index_ == 0) { // nick EXPECT_TRUE( tc_unsharded_.cpu().allclose(tc_unsharded_expected_, 1e-1, 1e-1)) << "Unexpected results, obtained: " << tc_unsharded_ - << "expected: " << tc_unsharded_expected_; + << "expected: " << tc_unsharded_expected_; } } }; @@ -1054,21 +1055,25 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) (my_device_index_ + j) % number_of_steps_per_ring_; auto tb_j = tb_.select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); + if (my_device_index_ == 0) { + //std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << " tc_unsharded_=" << tc_unsharded_ << std::endl; + std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << std::endl; + } auto ring_buffer_j = ring_buffer_.select(0, i); // recv next index - std::vector dst = {ring_buffer_j}; + std::vector dst = {(i % 2 == 0) ? ring_buffer_j : tb_j}; auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); if (recv_req) { + std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << " waiting on recv" << std::endl; recv_req->wait(); } - // if it's the first iteration, ring_buffer_j is empty and we haven't yet taken care of our tb_j - auto sendbuf_ = recv_req ? ring_buffer_j : tb_j; + auto sendbuf_ = (i % 2 == 0) ? 
tb_j : ring_buffer_j; // send & matmul current index - std::vector src = {ring_buffer_j}; + std::vector src = {sendbuf_}; world_communicator_->send(src, send_rank, 0); torch::matmul_out(tc_j, ta_unsharded_, sendbuf_); recv_req = next_recv_req; From e4427d47211e989431303d9218f01627470e6d08 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Mon, 4 Nov 2024 13:18:34 -0800 Subject: [PATCH 08/19] fix nccl deadlock --- tests/cpp/test_multidevice_overlap.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index eaffd5527c9..d42c511f483 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1039,6 +1039,11 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + bool rank_0_first = my_device_index_ == 0; // true if im rank 0 and i havent posted any communications yet + + // posting some collective to make sure nccl is initialized + //c10d::BarrierOptions barrier_opts = {all_devices_, std::chrono::milliseconds(100), my_device_index_}; + //world_communicator_->barrier(barrier_opts)->wait(); for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { @@ -1047,6 +1052,7 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) for (auto i : c10::irange(number_of_rings_)) { c10::intrusive_ptr recv_req = nullptr; + c10::intrusive_ptr next_recv_req = nullptr; for (auto j : c10::irange(number_of_steps_per_ring_)) { int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); @@ -1063,7 +1069,9 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) // recv next index std::vector dst = {(i % 2 == 0) ? 
ring_buffer_j : tb_j}; - auto next_recv_req = world_communicator_->recv(dst, recv_rank, 0); + if (!rank_0_first) { + next_recv_req = world_communicator_->recv(dst, recv_rank, 0); + } if (recv_req) { std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << " waiting on recv" << std::endl; @@ -1076,6 +1084,11 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) std::vector src = {sendbuf_}; world_communicator_->send(src, send_rank, 0); torch::matmul_out(tc_j, ta_unsharded_, sendbuf_); + if (rank_0_first) { + // let rank 0 post a send before his recv on the first iteration of the loop to avoid deadlock + next_recv_req = world_communicator_->recv(dst, recv_rank, 0); + rank_0_first = false; + } recv_req = next_recv_req; } } From a28d27e240986c3f17c8b5a081371f4c27f2cfff Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Mon, 4 Nov 2024 13:50:07 -0800 Subject: [PATCH 09/19] working with full B allocation, only copying the sharded portion from tb_unsharded_ to tb_ --- tests/cpp/test_multidevice_overlap.cpp | 39 +++++++++----------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index d42c511f483..84f3ff82179 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -947,7 +947,6 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { std::vector all_devices_; at::Tensor ta_unsharded_, tb_unsharded_, tc_unsharded_; at::Tensor tb_; - at::Tensor ring_buffer_; // stores the backend c10d::Backend* world_communicator_; @@ -956,7 +955,6 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { std::vector tb_unsharded_sizes; std::vector tc_unsharded_sizes; std::vector tb_sizes; - std::vector buffer_sizes; void SetUp() { MultiDeviceTest::SetUp(); @@ -985,9 +983,8 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { // C(M, N) ta_unsharded_sizes = std::vector{params.M, params.K}; tb_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - tb_sizes = std::vector{number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tb_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; tc_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.M, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - buffer_sizes = tb_sizes; // Set up input tensors. We create the full unsharded tensors and define the // actual input as the shard corresponding to the current device. 
Having the @@ -1001,7 +998,6 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); tb_ = at::empty(tb_sizes, gpu_options); - ring_buffer_ = at::empty(buffer_sizes, gpu_options); // Debug print if (communicator_->deviceId() == 0 && debug_print) { @@ -1018,18 +1014,18 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { void initializeIO() { ta_unsharded_.uniform_(); tb_unsharded_.uniform_(); - tb_.copy_(tb_unsharded_.select(0, my_device_index_)); + // we have allocated the full B matrix, but only copy the sharded portion + tb_.select(0, my_device_index_).copy_(tb_unsharded_.select(0, my_device_index_)); } void validate() { // compute the expected output for data correctness validation auto tc_unsharded_expected_ = torch::matmul(ta_unsharded_.cpu(), tb_unsharded_); - if (my_device_index_ == 0) { // nick EXPECT_TRUE( tc_unsharded_.cpu().allclose(tc_unsharded_expected_, 1e-1, 1e-1)) << "Unexpected results, obtained: " << tc_unsharded_ - << "expected: " << tc_unsharded_expected_; } + << "expected: " << tc_unsharded_expected_; } }; @@ -1039,11 +1035,7 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; - bool rank_0_first = my_device_index_ == 0; // true if im rank 0 and i havent posted any communications yet - - // posting some collective to make sure nccl is initialized - //c10d::BarrierOptions barrier_opts = {all_devices_, std::chrono::milliseconds(100), my_device_index_}; - //world_communicator_->barrier(barrier_opts)->wait(); + bool rank_0_first = my_device_index_ == 0; // true iff im rank 0 and i havent posted any communications yet for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { @@ -1058,32 +1050,27 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) setCurrentCUDAStream(streams.at(stream_index)); auto slice_index = - (my_device_index_ + j) % number_of_steps_per_ring_; - auto tb_j = tb_.select(0, i); + (my_device_index_ - j + number_of_steps_per_ring_) % number_of_steps_per_ring_; + auto next_slice_index = + (my_device_index_ - j - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + auto tb_j_curr_slice = tb_.select(0, slice_index).select(0, i); + auto tb_j_next_slice = tb_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); - if (my_device_index_ == 0) { - //std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << " tc_unsharded_=" << tc_unsharded_ << std::endl; - std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << std::endl; - } - auto ring_buffer_j = ring_buffer_.select(0, i); // recv next index - std::vector dst = {(i % 2 == 0) ? ring_buffer_j : tb_j}; + std::vector dst = {tb_j_next_slice}; if (!rank_0_first) { next_recv_req = world_communicator_->recv(dst, recv_rank, 0); } if (recv_req) { - std::cout << "nick: my_device_index_=" << my_device_index_ << " ring=" << i << ", ring_slice=" << slice_index << " waiting on recv" << std::endl; recv_req->wait(); } - auto sendbuf_ = (i % 2 == 0) ? 
tb_j : ring_buffer_j; - // send & matmul current index - std::vector src = {sendbuf_}; + std::vector src = {tb_j_curr_slice}; world_communicator_->send(src, send_rank, 0); - torch::matmul_out(tc_j, ta_unsharded_, sendbuf_); + torch::matmul_out(tc_j, ta_unsharded_, tb_j_curr_slice); if (rank_0_first) { // let rank 0 post a send before his recv on the first iteration of the loop to avoid deadlock next_recv_req = world_communicator_->recv(dst, recv_rank, 0); From 99330397761cac18402b1d44e9c9a3608c95b6f8 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Mon, 4 Nov 2024 13:54:20 -0800 Subject: [PATCH 10/19] revert back params --- tests/cpp/test_multidevice_overlap.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 84f3ff82179..0a2b576164c 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -45,16 +45,16 @@ struct OverlapTestParams { int64_t M = std::pow(2, 6); int64_t K = std::pow(2, 5); int64_t N = std::pow(2, 4); - int64_t S = std::pow(2, 2); + int64_t S = std::pow(2, 3); // network backend type CommunicatorBackend backend_type = CommunicatorBackend::kNccl; // Overlap optimization parameters // fill input with new random values and repeat the operation - int64_t number_of_iterations = 1; + int64_t number_of_iterations = 4; // Change CUDA stream at each iteration in a Round-Robin fashion - int64_t number_of_streams = 128; + int64_t number_of_streams = 3; }; std::ostream& operator<<(std::ostream& out, const OverlapTestParams& params) { From 925b9417f301d266571528a87b8163abae16da5d Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Sun, 10 Nov 2024 17:48:26 -0800 Subject: [PATCH 11/19] use coalesced groups for sendrecv --- tests/cpp/test_multidevice_overlap.cpp | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 0a2b576164c..7a2ebe8061b 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1035,7 +1035,6 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; - bool rank_0_first = my_device_index_ == 0; // true iff im rank 0 and i havent posted any communications yet for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { @@ -1043,8 +1042,7 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) initializeIO(); for (auto i : c10::irange(number_of_rings_)) { - c10::intrusive_ptr recv_req = nullptr; - c10::intrusive_ptr next_recv_req = nullptr; + c10::intrusive_ptr comms_req = nullptr; for (auto j : c10::irange(number_of_steps_per_ring_)) { int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); @@ -1057,26 +1055,18 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) auto tb_j_next_slice = tb_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); - // recv next index - std::vector dst = {tb_j_next_slice}; - if (!rank_0_first) { - next_recv_req = world_communicator_->recv(dst, recv_rank, 0); - } - - if (recv_req) { - recv_req->wait(); + if (comms_req) { + comms_req->wait(); } // send & matmul 
current index std::vector src = {tb_j_curr_slice}; - world_communicator_->send(src, send_rank, 0); + std::vector dst = {tb_j_next_slice}; torch::matmul_out(tc_j, ta_unsharded_, tb_j_curr_slice); - if (rank_0_first) { - // let rank 0 post a send before his recv on the first iteration of the loop to avoid deadlock - next_recv_req = world_communicator_->recv(dst, recv_rank, 0); - rank_0_first = false; - } - recv_req = next_recv_req; + world_communicator_->startCoalescing(); + world_communicator_->send(src, send_rank, 0); + world_communicator_->recv(dst, recv_rank, 0); + comms_req = world_communicator_->endCoalescing(); } } synchronizeStreams(streams); From 97b3c0a21e4415db3a512952b052359409ea0aa7 Mon Sep 17 00:00:00 2001 From: nsarka Date: Tue, 12 Nov 2024 13:10:40 -0500 Subject: [PATCH 12/19] Update tests/cpp/test_multidevice_overlap.cpp Co-authored-by: samnordmann --- tests/cpp/test_multidevice_overlap.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 7a2ebe8061b..fc247536a35 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1033,8 +1033,8 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) std::vector streams = createStreams(params.number_of_streams, my_device_index_); - auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; - auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + const auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; + const auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { From b60ec552ae35e98de9ef4bf560cf5d6a1cdb89eb Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Fri, 29 Nov 2024 12:33:27 -0800 Subject: [PATCH 13/19] Return if comm not available --- tests/cpp/test_multidevice_overlap.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index fc247536a35..f85629984bc 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -958,6 +958,9 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { void SetUp() { MultiDeviceTest::SetUp(); + if (!communicator_->is_available()) { + return; + } num_devices_ = communicator_->size(); my_device_index_ = communicator_->deviceId(); From b42451ce7a576bcaadf0f4790c01485d4c224e5e Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Fri, 29 Nov 2024 12:50:28 -0800 Subject: [PATCH 14/19] Avoid last sendrecv --- tests/cpp/test_multidevice_overlap.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index f85629984bc..c267b0bfb0d 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1060,16 +1060,19 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) if (comms_req) { comms_req->wait(); + comms_req = nullptr; } // send & matmul current index std::vector src = {tb_j_curr_slice}; std::vector dst = {tb_j_next_slice}; torch::matmul_out(tc_j, ta_unsharded_, tb_j_curr_slice); - world_communicator_->startCoalescing(); - world_communicator_->send(src, send_rank, 0); - world_communicator_->recv(dst, recv_rank, 0); - comms_req = 
world_communicator_->endCoalescing(); + if (j < number_of_steps_per_ring_ - 1) { + world_communicator_->startCoalescing(); + world_communicator_->send(src, send_rank, 0); + world_communicator_->recv(dst, recv_rank, 0); + comms_req = world_communicator_->endCoalescing(); + } } } synchronizeStreams(streams); From 705c33a62c7520896047247ee0d8cf99afc14f46 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Fri, 29 Nov 2024 13:35:58 -0800 Subject: [PATCH 15/19] Linter --- tests/cpp/test_multidevice_overlap.cpp | 40 ++++++++++++++++++-------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index c267b0bfb0d..bf8918657b7 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -985,9 +985,21 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { // B(K, sharded(N)) // C(M, N) ta_unsharded_sizes = std::vector{params.M, params.K}; - tb_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - tb_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.K, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - tc_unsharded_sizes = std::vector{number_of_steps_per_ring_, number_of_rings_, params.M, params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tb_unsharded_sizes = std::vector{ + number_of_steps_per_ring_, + number_of_rings_, + params.K, + params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tb_sizes = std::vector{ + number_of_steps_per_ring_, + number_of_rings_, + params.K, + params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + tc_unsharded_sizes = std::vector{ + number_of_steps_per_ring_, + number_of_rings_, + params.M, + params.N / (number_of_steps_per_ring_ * number_of_rings_)}; // Set up input tensors. We create the full unsharded tensors and define the // actual input as the shard corresponding to the current device. 
Having the @@ -1000,14 +1012,14 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { ta_unsharded_ = at::empty(ta_unsharded_sizes, gpu_options); tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); - tb_ = at::empty(tb_sizes, gpu_options); + tb_ = at::empty(tb_sizes, gpu_options); // Debug print if (communicator_->deviceId() == 0 && debug_print) { debug() << "ta_unsharded_sizes()=" << ta_unsharded_.sizes() << std::endl << "tb_unsharded_sizes()=" << tb_unsharded_.sizes() << std::endl << "tc_unsharded_sizes()=" << tc_unsharded_.sizes() << std::endl - << "tb_.sizes()=" << tb_.sizes() << std::endl; + << "tb_.sizes()=" << tb_.sizes() << std::endl; } } @@ -1018,7 +1030,8 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { ta_unsharded_.uniform_(); tb_unsharded_.uniform_(); // we have allocated the full B matrix, but only copy the sharded portion - tb_.select(0, my_device_index_).copy_(tb_unsharded_.select(0, my_device_index_)); + tb_.select(0, my_device_index_) + .copy_(tb_unsharded_.select(0, my_device_index_)); } void validate() { @@ -1032,16 +1045,18 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { } }; -TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) { +TEST_F( + RingAllgatherOverlapTest, + RingAllgatherBasedPipeliningATenImplementation) { std::vector streams = createStreams(params.number_of_streams, my_device_index_); const auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_; - const auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + const auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) % + number_of_steps_per_ring_; for ([[maybe_unused]] const auto& _ : c10::irange(params.number_of_iterations)) { - initializeIO(); for (auto i : c10::irange(number_of_rings_)) { @@ -1050,10 +1065,11 @@ TEST_F(RingAllgatherOverlapTest, RingAllgatherBasedPipeliningATenImplementation) int64_t stream_index = (i + j) % streams.size(); setCurrentCUDAStream(streams.at(stream_index)); - auto slice_index = - (my_device_index_ - j + number_of_steps_per_ring_) % number_of_steps_per_ring_; + auto slice_index = (my_device_index_ - j + number_of_steps_per_ring_) % + number_of_steps_per_ring_; auto next_slice_index = - (my_device_index_ - j - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; + (my_device_index_ - j - 1 + number_of_steps_per_ring_) % + number_of_steps_per_ring_; auto tb_j_curr_slice = tb_.select(0, slice_index).select(0, i); auto tb_j_next_slice = tb_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); From 04bb5cc8f233599948b2cd957f9c5ca39c9a9e22 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Thu, 5 Dec 2024 11:24:28 -0800 Subject: [PATCH 16/19] Transpose to match coll based overlap test --- tests/cpp/test_multidevice_overlap.cpp | 56 ++++++++++++++------------ 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index bf8918657b7..a93be942d82 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -946,7 +946,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { int64_t number_of_steps_per_ring_, number_of_rings_; std::vector all_devices_; at::Tensor ta_unsharded_, tb_unsharded_, tc_unsharded_; - at::Tensor tb_; + at::Tensor ta_; // stores the backend c10d::Backend* 
world_communicator_; @@ -954,7 +954,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { std::vector ta_unsharded_sizes; std::vector tb_unsharded_sizes; std::vector tc_unsharded_sizes; - std::vector tb_sizes; + std::vector ta_sizes; void SetUp() { MultiDeviceTest::SetUp(); @@ -981,25 +981,29 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { debug() << params << std::endl; } - // A(M, K) - // B(K, sharded(N)) + // A(sharded(M), K) + // B(K, N) // C(M, N) - ta_unsharded_sizes = std::vector{params.M, params.K}; - tb_unsharded_sizes = std::vector{ + // We use a full allocation of A on each device for algorithm simplicity + // TODO: use only 2 buffers instead of a full allocation + ta_unsharded_sizes = std::vector{ number_of_steps_per_ring_, number_of_rings_, - params.K, - params.N / (number_of_steps_per_ring_ * number_of_rings_)}; - tb_sizes = std::vector{ + params.M / (number_of_steps_per_ring_ * number_of_rings_), + params.K}; + ta_sizes = std::vector{ number_of_steps_per_ring_, number_of_rings_, + params.M / (number_of_steps_per_ring_ * number_of_rings_), + params.K}; + tb_unsharded_sizes = std::vector{ params.K, - params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + params.N}; tc_unsharded_sizes = std::vector{ number_of_steps_per_ring_, number_of_rings_, - params.M, - params.N / (number_of_steps_per_ring_ * number_of_rings_)}; + params.M / (number_of_steps_per_ring_ * number_of_rings_), + params.N}; // Set up input tensors. We create the full unsharded tensors and define the // actual input as the shard corresponding to the current device. Having the @@ -1009,35 +1013,35 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { auto cpu_options = at::TensorOptions().dtype(at::kFloat); at::TensorOptions gpu_options = cpu_options.device(communicator_->device()); - ta_unsharded_ = at::empty(ta_unsharded_sizes, gpu_options); - tb_unsharded_ = at::empty(tb_unsharded_sizes, cpu_options); + ta_unsharded_ = at::empty(ta_unsharded_sizes, cpu_options); + tb_unsharded_ = at::empty(tb_unsharded_sizes, gpu_options); tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options); - tb_ = at::empty(tb_sizes, gpu_options); + ta_ = at::empty(ta_sizes, gpu_options); // Debug print if (communicator_->deviceId() == 0 && debug_print) { debug() << "ta_unsharded_sizes()=" << ta_unsharded_.sizes() << std::endl << "tb_unsharded_sizes()=" << tb_unsharded_.sizes() << std::endl << "tc_unsharded_sizes()=" << tc_unsharded_.sizes() << std::endl - << "tb_.sizes()=" << tb_.sizes() << std::endl; + << "ta_.sizes()=" << ta_.sizes() << std::endl; } } // Each rank calls uniform_ and gets the same values for ta_ and tb_ because // the random seed is initialized the same Therefore, we do not need to have - // one rank generate ta_ and tb_ and broadcast it to the rest of the ranks + // one rank generate A and B and broadcast it to the rest of the ranks void initializeIO() { ta_unsharded_.uniform_(); tb_unsharded_.uniform_(); - // we have allocated the full B matrix, but only copy the sharded portion - tb_.select(0, my_device_index_) - .copy_(tb_unsharded_.select(0, my_device_index_)); + // we have allocated the full A matrix, but only copy the sharded portion + ta_.select(0, my_device_index_) + .copy_(ta_unsharded_.select(0, my_device_index_)); } void validate() { // compute the expected output for data correctness validation auto tc_unsharded_expected_ = - torch::matmul(ta_unsharded_.cpu(), tb_unsharded_); + torch::matmul(ta_unsharded_, tb_unsharded_.cpu()); EXPECT_TRUE( 
tc_unsharded_.cpu().allclose(tc_unsharded_expected_, 1e-1, 1e-1)) << "Unexpected results, obtained: " << tc_unsharded_ @@ -1070,8 +1074,8 @@ TEST_F( auto next_slice_index = (my_device_index_ - j - 1 + number_of_steps_per_ring_) % number_of_steps_per_ring_; - auto tb_j_curr_slice = tb_.select(0, slice_index).select(0, i); - auto tb_j_next_slice = tb_.select(0, next_slice_index).select(0, i); + auto ta_j_curr_slice = ta_.select(0, slice_index).select(0, i); + auto ta_j_next_slice = ta_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); if (comms_req) { @@ -1080,9 +1084,9 @@ TEST_F( } // send & matmul current index - std::vector src = {tb_j_curr_slice}; - std::vector dst = {tb_j_next_slice}; - torch::matmul_out(tc_j, ta_unsharded_, tb_j_curr_slice); + std::vector src = {ta_j_curr_slice}; + std::vector dst = {ta_j_next_slice}; + torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_); if (j < number_of_steps_per_ring_ - 1) { world_communicator_->startCoalescing(); world_communicator_->send(src, send_rank, 0); From e32e9cf8e791109f71c9fcc6e0f0c6bcb3fc5a26 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Thu, 5 Dec 2024 12:15:40 -0800 Subject: [PATCH 17/19] linter --- tests/cpp/test_multidevice_overlap.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index a93be942d82..fffdd269908 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -996,9 +996,7 @@ class RingAllgatherOverlapTest : public MultiDeviceTest { number_of_rings_, params.M / (number_of_steps_per_ring_ * number_of_rings_), params.K}; - tb_unsharded_sizes = std::vector{ - params.K, - params.N}; + tb_unsharded_sizes = std::vector{params.K, params.N}; tc_unsharded_sizes = std::vector{ number_of_steps_per_ring_, number_of_rings_, From 04ee267e28d151a9c11599777ec4b1cec19d319c Mon Sep 17 00:00:00 2001 From: nsarka Date: Fri, 6 Dec 2024 15:13:42 -0500 Subject: [PATCH 18/19] Update tests/cpp/test_multidevice_overlap.cpp Co-authored-by: Jingyue Wu --- tests/cpp/test_multidevice_overlap.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index fffdd269908..6ed8fbec4de 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1076,7 +1076,7 @@ TEST_F( auto ta_j_next_slice = ta_.select(0, next_slice_index).select(0, i); auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i); - if (comms_req) { + if (comms_req != nullptr) { comms_req->wait(); comms_req = nullptr; } From 06aae9befe846ef350df4be57cf4119959704780 Mon Sep 17 00:00:00 2001 From: Nicholas Sarkauskas Date: Tue, 10 Dec 2024 08:18:01 -0800 Subject: [PATCH 19/19] Post gemm after comm --- tests/cpp/test_multidevice_overlap.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 6ed8fbec4de..c93296aaa27 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1084,13 +1084,13 @@ TEST_F( // send & matmul current index std::vector src = {ta_j_curr_slice}; std::vector dst = {ta_j_next_slice}; - torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_); if (j < number_of_steps_per_ring_ - 1) { world_communicator_->startCoalescing(); world_communicator_->send(src, send_rank, 0); 
world_communicator_->recv(dst, recv_rank, 0); comms_req = world_communicator_->endCoalescing(); } + torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_); } } synchronizeStreams(streams);
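
For reference, the inner loop that this series converges on (patches 14-19 above) overlaps the local GEMM on the A-slice a rank currently holds with a coalesced send/recv that rotates the next slice around the ring. The sketch below is a consolidated, illustrative view of that loop rather than an additional patch: it uses only the members, helpers, and c10d calls already introduced in the diffs above, and spells out the template arguments (at::Tensor, c10d::Work, assumed from the c10d API) that the diffs show without them.

    // Per-ring pipeline: at step j, compute on the slice currently held while
    // the coalesced send/recv brings in the slice needed at step j + 1.
    for (auto i : c10::irange(number_of_rings_)) {
      c10::intrusive_ptr<c10d::Work> comms_req = nullptr;
      for (auto j : c10::irange(number_of_steps_per_ring_)) {
        // Round-robin over the pre-created CUDA streams.
        setCurrentCUDAStream(streams.at((i + j) % streams.size()));

        // Slice of A held at this step, and the slot the incoming slice lands in
        // (indices walk backwards around the ring).
        auto slice_index = (my_device_index_ - j + number_of_steps_per_ring_) %
            number_of_steps_per_ring_;
        auto next_slice_index =
            (my_device_index_ - j - 1 + number_of_steps_per_ring_) %
            number_of_steps_per_ring_;
        auto ta_j_curr_slice = ta_.select(0, slice_index).select(0, i);
        auto ta_j_next_slice = ta_.select(0, next_slice_index).select(0, i);
        auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i);

        // Make sure the slice requested at the previous step has arrived
        // before computing on or forwarding it.
        if (comms_req != nullptr) {
          comms_req->wait();
          comms_req = nullptr;
        }

        // Post the coalesced send/recv for the next slice (skipped on the last
        // step, when nothing is left to rotate), then issue the GEMM so it
        // overlaps with the in-flight communication.
        if (j < number_of_steps_per_ring_ - 1) {
          std::vector<at::Tensor> src = {ta_j_curr_slice};
          std::vector<at::Tensor> dst = {ta_j_next_slice};
          world_communicator_->startCoalescing();
          world_communicator_->send(src, send_rank, /*tag=*/0);
          world_communicator_->recv(dst, recv_rank, /*tag=*/0);
          comms_req = world_communicator_->endCoalescing();
        }
        torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_);
      }
    }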