Ring-based decomposition for Allgather+GEMM overlap ATen implementation #3392

Merged
merged 19 commits on Dec 13, 2024
161 changes: 161 additions & 0 deletions tests/cpp/test_multidevice_overlap.cpp
@@ -937,4 +937,165 @@ TEST_F(AllgatherOverlapTest, AllgatherBasedPipeliningHostIrImplementation) {
}
}

class RingAllgatherOverlapTest : public MultiDeviceTest {
protected:
OverlapTestParams params;

int64_t num_devices_;
int64_t my_device_index_;
int64_t number_of_steps_per_ring_, number_of_rings_;
std::vector<int64_t> all_devices_;
at::Tensor ta_unsharded_, tb_unsharded_, tc_unsharded_;
at::Tensor ta_;
// Backend used for the ring's send/recv communications across all devices
c10d::Backend* world_communicator_;

// Define I/O and intermediate Tensor shapes
std::vector<int64_t> ta_unsharded_sizes;
std::vector<int64_t> tb_unsharded_sizes;
std::vector<int64_t> tc_unsharded_sizes;
std::vector<int64_t> ta_sizes;

void SetUp() {
MultiDeviceTest::SetUp();
if (!communicator_->is_available()) {
return;
}

num_devices_ = communicator_->size();
my_device_index_ = communicator_->deviceId();

ASSERT_EQ(params.S % num_devices_, 0);
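// The M axis of A is split into params.S slices in total. Each ring steps
// through num_devices_ slices (one slice per step), so params.S / num_devices_
// independent rings are processed.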
number_of_steps_per_ring_ = num_devices_;
number_of_rings_ = params.S / num_devices_;

// Set up the world communicator
std::vector<int64_t> devices(num_devices_);
std::iota(devices.begin(), devices.end(), 0);
all_devices_ = std::move(devices);
world_communicator_ =
communicator_->getBackendForTeam(all_devices_, params.backend_type);

// Debug print
if (communicator_->deviceId() == 0 && debug_print) {
debug() << params << std::endl;
}

// A(sharded(M), K)
// B(K, N)
// C(M, N)
// We use a full allocation of A on each device for algorithm simplicity
// TODO: use only 2 buffers instead of a full allocation
ta_unsharded_sizes = std::vector<int64_t>{
number_of_steps_per_ring_,
number_of_rings_,
params.M / (number_of_steps_per_ring_ * number_of_rings_),
params.K};
ta_sizes = std::vector<int64_t>{
number_of_steps_per_ring_,
number_of_rings_,
params.M / (number_of_steps_per_ring_ * number_of_rings_),
params.K};
tb_unsharded_sizes = std::vector<int64_t>{params.K, params.N};
tc_unsharded_sizes = std::vector<int64_t>{
number_of_steps_per_ring_,
number_of_rings_,
params.M / (number_of_steps_per_ring_ * number_of_rings_),
params.N};

// Set up input tensors. We create the full unsharded tensors and define the
// actual input as the shard corresponding to the current device. Having the
// full unsharded input on each rank makes it possible to compute the expected
// result locally, which is convenient for validating data correctness.
auto cpu_options = at::TensorOptions().dtype(at::kFloat);
at::TensorOptions gpu_options = cpu_options.device(communicator_->device());

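// ta_unsharded_ is allocated on the CPU and serves as the reference copy of
// A: its local shard is copied into the GPU buffer ta_ in initializeIO(), and
// validate() uses it to compute the expected result. B, C, and the sharded A
// buffer live on the GPU.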
ta_unsharded_ = at::empty(ta_unsharded_sizes, cpu_options);
tb_unsharded_ = at::empty(tb_unsharded_sizes, gpu_options);
tc_unsharded_ = at::empty(tc_unsharded_sizes, gpu_options);
ta_ = at::empty(ta_sizes, gpu_options);

// Debug print
if (communicator_->deviceId() == 0 && debug_print) {
debug() << "ta_unsharded_sizes()=" << ta_unsharded_.sizes() << std::endl
<< "tb_unsharded_sizes()=" << tb_unsharded_.sizes() << std::endl
<< "tc_unsharded_sizes()=" << tc_unsharded_.sizes() << std::endl
<< "ta_.sizes()=" << ta_.sizes() << std::endl;
}
}

// Each rank calls uniform_ and obtains the same values for ta_unsharded_ and
// tb_unsharded_ because the random seed is initialized identically on every
// rank. Therefore, there is no need for one rank to generate A and B and
// broadcast them to the other ranks.
void initializeIO() {
ta_unsharded_.uniform_();
tb_unsharded_.uniform_();
// we have allocated the full A matrix, but only copy the sharded portion
ta_.select(0, my_device_index_)
.copy_(ta_unsharded_.select(0, my_device_index_));
}

void validate() {
// compute the expected output for data correctness validation
auto tc_unsharded_expected_ =
torch::matmul(ta_unsharded_, tb_unsharded_.cpu());
EXPECT_TRUE(
tc_unsharded_.cpu().allclose(tc_unsharded_expected_, 1e-1, 1e-1))
<< "Unexpected results, obtained: " << tc_unsharded_
<< "expected: " << tc_unsharded_expected_;
}
};

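// Ring-based allgather + GEMM overlap: at every step, each device computes a
// GEMM on the slice of A it currently holds while simultaneously forwarding
// that slice to the next device in the ring and receiving a new slice from
// the previous device.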
TEST_F(
RingAllgatherOverlapTest,
RingAllgatherBasedPipeliningATenImplementation) {
std::vector<c10::cuda::CUDAStream> streams =
createStreams(params.number_of_streams, my_device_index_);

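// Ring neighbors: each device sends to the next device index and receives
// from the previous one, with indices wrapping around the ring.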
const auto send_rank = (my_device_index_ + 1) % number_of_steps_per_ring_;
const auto recv_rank = (my_device_index_ - 1 + number_of_steps_per_ring_) %
number_of_steps_per_ring_;

for ([[maybe_unused]] const auto& _ :
c10::irange(params.number_of_iterations)) {
initializeIO();

for (auto i : c10::irange(number_of_rings_)) {
c10::intrusive_ptr<c10d::Work> comms_req = nullptr;
for (auto j : c10::irange(number_of_steps_per_ring_)) {
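// Rotate through the available streams so that work from different
// (ring, step) pairs can overlap.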
int64_t stream_index = (i + j) % streams.size();
setCurrentCUDAStream(streams.at(stream_index));

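// At step j, this device holds slice (my_device_index_ - j) mod P of ring i,
// where P = number_of_steps_per_ring_. It computes on that slice and receives
// slice (my_device_index_ - j - 1) mod P for the next step.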
auto slice_index = (my_device_index_ - j + number_of_steps_per_ring_) %
number_of_steps_per_ring_;
auto next_slice_index =
(my_device_index_ - j - 1 + number_of_steps_per_ring_) %
number_of_steps_per_ring_;
auto ta_j_curr_slice = ta_.select(0, slice_index).select(0, i);
auto ta_j_next_slice = ta_.select(0, next_slice_index).select(0, i);
auto tc_j = tc_unsharded_.select(0, slice_index).select(0, i);

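// Make sure the send/recv issued at the previous step has completed before
// reading the slice that was just received.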
if (comms_req != nullptr) {
comms_req->wait();
comms_req = nullptr;
}

// Forward the current slice to the next rank, receive the next slice from the
// previous rank, and compute the matmul on the current slice.
std::vector<at::Tensor> src = {ta_j_curr_slice};
std::vector<at::Tensor> dst = {ta_j_next_slice};
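// On the last step every device already holds all the slices it needs, so no
// further communication is posted.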
if (j < number_of_steps_per_ring_ - 1) {
world_communicator_->startCoalescing();
world_communicator_->send(src, send_rank, 0);
world_communicator_->recv(dst, recv_rank, 0);
comms_req = world_communicator_->endCoalescing();
}
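// The GEMM on the current slice is queued on the active stream and overlaps
// with the in-flight send/recv.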
torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_);
}
}
synchronizeStreams(streams);
validate();
}
}

} // namespace nvfuser