diff --git a/tests/cpp/test_multidevice_overlap.cpp b/tests/cpp/test_multidevice_overlap.cpp index 6ed8fbec4de..c93296aaa27 100644 --- a/tests/cpp/test_multidevice_overlap.cpp +++ b/tests/cpp/test_multidevice_overlap.cpp @@ -1084,13 +1084,13 @@ TEST_F( // send & matmul current index std::vector src = {ta_j_curr_slice}; std::vector dst = {ta_j_next_slice}; - torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_); if (j < number_of_steps_per_ring_ - 1) { world_communicator_->startCoalescing(); world_communicator_->send(src, send_rank, 0); world_communicator_->recv(dst, recv_rank, 0); comms_req = world_communicator_->endCoalescing(); } + torch::matmul_out(tc_j, ta_j_curr_slice, tb_unsharded_); } } synchronizeStreams(streams);