From 009e334af7ee713d015907c1103027282e74f3ef Mon Sep 17 00:00:00 2001
From: hanchao
Date: Fri, 13 Sep 2024 05:30:33 +0000
Subject: [PATCH] support reduce_scatter

---
 .../distributed/c10d/ProcessGroupXCCL.cpp | 198 +++++++++++++++++-
 .../distributed/c10d/ProcessGroupXCCL.hpp |  18 +-
 2 files changed, 206 insertions(+), 10 deletions(-)

diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index e78bc4e49871b..4792ba86682f0 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -272,6 +272,13 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
+constexpr const char* MULTI_DEVICE_ERROR_MSG =
+    "Expecting one tensor only but got multiple. You are probably using multiple "
+    "devices under one thread. The support for such usage has been deprecated. "
+    "For details, please refer to "
+    "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
+    "ProcessGroupXCCL continues supporting multi-process and multi-thread modes.";
+
 ProcessGroupXCCL::ProcessGroupXCCL(
     const c10::intrusive_ptr<Store>& store,
     int rank,
@@ -376,7 +383,50 @@ void ProcessGroupXCCL::groupEnd() {
   --xcclActiveGroupCounter_;
 }
 
-// align with good design single-device style, input_t and output_t due to
+// TODO: wait p2p enable
+static constexpr int CoalActive = 0x01, CoalColl = 0x02;
+void ProcessGroupXCCL::startCoalescing() {
+  coalescedDevice_.set_index(-1);
+  coalescedComm_ = nullptr;
+  coalescing_state_ |= CoalActive;
+  groupStart();
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
+  if (coalescedComm_ == nullptr) {
+    // There is no actual work being coalesced, return here
+    groupEnd();
+    coalescing_state_ = 0;
+    return nullptr;
+  }
+  TORCH_CHECK(
+      coalescedDevice_.index() >= 0,
+      "Something went wrong. Did you call end_coalescing before start_coalescing?");
+
+  auto comm = coalescedComm_;
+  auto device = coalescedDevice_;
+
+  const auto key = std::to_string(device.index());
+  auto stream = xcclStreams_.at(key);
+
+  auto work = initWork(device, rank_, optype);
+  work->blockingWait_ = blockingWait_;
+
+  groupEnd();
+
+  work->xcclEndEvent_->record(stream);
+
+  coalescing_state_ = 0;
+  coalescedComm_ = nullptr;
+  return work;
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing() {
+  // Default OpType to COALESCED if not specified
+  return endCoalescing(OpType::COALESCED);
+}
+
+// align with single-device style, input_t and output_t due to
 // allgatherv need vector output
 template <
     typename Fn,
@@ -399,6 +449,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   const auto key = std::to_string(device.index());
   auto comm = getXCCLComm(key, device);
 
+  if (coalescing_state_ & CoalActive) {
+    coalescing_state_ |= CoalColl;
+    if (coalescedDevice_.index() < 0) {
+      coalescedDevice_ = device;
+    } else {
+      TORCH_CHECK(
+          coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
+    }
+    if (coalescedComm_ == nullptr) {
+      coalescedComm_ = comm;
+    } else {
+      TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
+    }
+  }
+
   auto stream = xcclStreams_.at(key);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -502,6 +567,21 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
   const auto key = std::to_string(device.index());
   auto comm = getXCCLComm(key, device);
 
+  if (coalescing_state_ & CoalActive) {
+    coalescing_state_ |= CoalColl;
+    if (coalescedDevice_.index() < 0) {
+      coalescedDevice_ = device;
+    } else {
+      TORCH_CHECK(
+          coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
+    }
+    if (coalescedComm_ == nullptr) {
+      coalescedComm_ = comm;
+    } else {
+      TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
+    }
+  }
+
   auto stream = xcclStreams_.at(key);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
@@ -535,8 +615,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
-  TORCH_CHECK(
-      tensors.size() == 1, "Expecting one tensor only but got multiple");
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
   check_xpu_single_tensor(tensor);
   return collective(
@@ -600,8 +679,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
 c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
     std::vector<at::Tensor>& tensors,
     const BroadcastOptions& opts) {
-  TORCH_CHECK(
-      tensors.size() == 1, "Expecting one tensor only but got multiple");
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
   if (tensor.is_complex()) {
     tensor = at::view_as_real(tensor);
@@ -633,12 +711,46 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
       OpType::BROADCAST);
 }
 
+c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const ReduceOptions& opts) {
+  if (outputTensor.numel() != inputTensor.numel()) {
+    C10_THROW_ERROR(
+        ValueError,
+        "Tensor input and output of _reduce_oop must have the same number of elements");
+  }
+  return collective(
+      inputTensor,
+      outputTensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          ccl::reduce_attr attr,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        const int root = opts.rootRank + opts.rootTensor;
+        const auto xcclDataType = getXcclDataType(input.scalar_type());
+        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::event ret_evt;
+        ret_evt = ccl::reduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            root,
+            comm,
+            ccl::create_stream(stream.queue()));
+        return ret_evt;
+      },
+      OpType::REDUCE);
+}
+
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
     std::vector<std::vector<at::Tensor>>& outputTensors,
     std::vector<at::Tensor>& inputTensors,
     const AllgatherOptions& opts) {
-  TORCH_CHECK(
-      inputTensors.size() == 1, "Expecting one tensor only but got multiple");
+  TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   // @lint-ignore CLANGTIDY
   auto inputTensor = inputTensors.back();
   check_xpu_single_tensor(inputTensor);
@@ -802,6 +914,78 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
       OpType::COALESCED);
 }
 
+c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
+    std::vector<at::Tensor>& outputTensors,
+    std::vector<std::vector<at::Tensor>>& inputTensors,
+    const ReduceScatterOptions& opts) {
+  TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  // @lint-ignore CLANGTIDY
+  auto outputTensor = outputTensors.back();
+  check_xpu_single_tensor(outputTensor);
+  // @lint-ignore CLANGTIDY
+  auto inputTensors_ = inputTensors.back();
+  TORCH_CHECK(
+      !isFloat8Type(outputTensor.scalar_type()),
+      "Float8 dtypes are not currently supported for XCCL reductions");
+
+  bool same_size = check_same_size(inputTensors_);
+  if (same_size) {
+    // Flatten a vector of tensors into a single, stacked tensor.
+    at::Tensor inputFlattened = newLikeFlat(inputTensors_);
+    return collective(
+        inputFlattened,
+        outputTensor,
+        [&](at::Tensor& input,
+            at::Tensor& output,
+            ccl::reduce_attr attr,
+            xcclComm_t& comm,
+            at::xpu::XPUStream& stream) {
+          c10::xpu::XPUCachingAllocator::recordStream(
+              output.storage().data_ptr(), stream);
+          auto xcclDataType = getXcclDataType(input.scalar_type());
+          auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+          ccl::event ret_evt;
+          ret_evt = ccl::reduce_scatter(
+              input.data_ptr(),
+              output.data_ptr(),
+              (size_t)output.numel(),
+              xcclDataType,
+              xcclReduceOp,
+              comm,
+              ccl::create_stream(stream.queue()));
+          return ret_evt;
+        },
+        [&](at::xpu::XPUStream& Stream,
+            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {
+          // Copy the input tensors to the flattened inputs.
+          c10::StreamGuard guard(Stream);
+          for (const auto j : c10::irange(inputTensors_.size())) {
+            c10::xpu::XPUCachingAllocator::recordStream(
+                inputTensors_[j].storage().data_ptr(), Stream);
+            inputFlattened[j].copy_(inputTensors_[j], true);
+          }
+        },
+        [&](at::xpu::XPUStream&,
+            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        OpType::REDUCE_SCATTER);
+  } else {
+    const auto num_reduces = inputTensors_.size();
+    startCoalescing();
+    for (const int i : c10::irange(num_reduces)) {
+      auto& input = inputTensors_[i];
+      auto& output = (i == rank_) ? outputTensor : input;
+      auto reduceOpts = ReduceOptions{
+          opts.reduceOp,
+          static_cast<int64_t>(i),
+          static_cast<int64_t>(0),
+          opts.timeout};
+      _reduce_oop(output, input, reduceOpts);
+    }
+    auto work = endCoalescing(OpType::REDUCE_SCATTER);
+    return work;
+  }
+}
+
 } // namespace c10d
 
 #endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 94ee71ab0190c..3b72fd4261f5c 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -149,6 +149,12 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     return std::string(XCCL_BACKEND_NAME);
   }
 
+  void startCoalescing() override;
+
+  c10::intrusive_ptr<Work> endCoalescing() override;
+
+  c10::intrusive_ptr<Work> endCoalescing(OpType optype);
+
   std::shared_ptr<xcclComm_t> getXCCLComm(
       const std::string& deviceKey,
       at::Device& device);
@@ -221,6 +227,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       std::vector<at::Tensor>& tensors,
       const BroadcastOptions& opts = BroadcastOptions()) override;
 
+  c10::intrusive_ptr<Work> _reduce_oop(
+      at::Tensor& outputTensor,
+      at::Tensor& inputTensor,
+      const ReduceOptions& opts = ReduceOptions());
+
   c10::intrusive_ptr<Work> allgather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
@@ -246,9 +257,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   c10::intrusive_ptr<Work> reduce_scatter(
       std::vector<at::Tensor>& outputTensors,
       std::vector<std::vector<at::Tensor>>& inputTensors,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
-    TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented");
-  }
+      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
 
   c10::intrusive_ptr<Work> _reduce_scatter_base(
       at::Tensor& outputTensor,
@@ -327,6 +336,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
   c10::intrusive_ptr<Store> store_;
   std::mutex mutex_;
+  int coalescing_state_ = 0;
+  at::Device coalescedDevice_ = at::Device("xpu");
+  std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
   bool blockingWait_ = false;
   static thread_local uint64_t xcclActiveGroupCounter_;
 };
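
Usage note (not part of the patch): a minimal sketch of how the new reduce_scatter
support could be exercised from torch.distributed once this backend is built. The
backend name "xccl", XPU device indexing, and the RANK/WORLD_SIZE/MASTER_ADDR/
MASTER_PORT rendezvous variables are assumptions for illustration, not something
this patch itself establishes.

# Minimal sketch: drive ProcessGroupXCCL::reduce_scatter from Python.
# Assumptions (hypothetical, not established by this patch): the backend is
# registered as "xccl", XPU devices are visible, and the launcher sets
# RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
import os

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("xccl", rank=rank, world_size=world_size)
device = torch.device("xpu", rank)

# Equal split sizes: exercises the flattened ccl::reduce_scatter path.
inputs = [torch.full((4,), float(rank), device=device) for _ in range(world_size)]
output = torch.empty(4, device=device)
dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)

# Unequal split sizes: exercises the coalesced per-rank _reduce_oop fallback.
uneven_inputs = [torch.ones(i + 1, device=device) for i in range(world_size)]
uneven_output = torch.empty(rank + 1, device=device)
dist.reduce_scatter(uneven_output, uneven_inputs, op=dist.ReduceOp.SUM)

dist.destroy_process_group()

The first call maps to the same-size branch added in this patch; the second maps
to the startCoalescing/_reduce_oop/endCoalescing(OpType::REDUCE_SCATTER) fallback.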