diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index 4e0b7db3592093..847b01b7680e8a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -26,6 +26,22 @@ namespace c10d {
 
 namespace {
+
+// TODO: nonblocking wait is not implemented yet
+AutoXcclGroup::AutoXcclGroup() {
+  comm_ = nullptr;
+  ccl::group_start();
+}
+
+AutoXcclGroup::AutoXcclGroup(std::shared_ptr<xcclComm_t> comm) {
+  comm_ = std::move(comm);
+  ccl::group_start();
+}
+
+AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
+  ccl::group_end();
+}
+
 std::map<ReduceOp, ccl::reduction> xcclOps = {
     {ReduceOp::MIN, ccl::reduction::min},
     {ReduceOp::MAX, ccl::reduction::max},
@@ -85,6 +101,34 @@ void check_xpu_single_tensor(const at::Tensor& tensor) {
   }
 }
 
+int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
+  if (tensors.size() == 0) {
+    C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
+  }
+
+  const auto& first = tensors.front();
+
+  int64_t total_numel = 0;
+  for (const auto& t : tensors) {
+    if (!t.is_xpu() || t.is_sparse()) {
+      C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
+    }
+    if (t.scalar_type() != first.scalar_type()) {
+      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
+    }
+    if (!t.is_non_overlapping_and_dense()) {
+      C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
+    }
+    TORCH_CHECK_WITH(
+        ValueError,
+        t.get_device() == tensors[0].get_device(),
+        "Expected list of tensors on the same device");
+    total_numel += t.numel();
+  }
+
+  return total_numel;
+}
+
 ccl::datatype getXcclDataType(at::ScalarType type) {
   auto it = xcclDatatypes.find(type);
   TORCH_CHECK_WITH(
@@ -132,6 +176,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
 static std::mutex xcclCommDevIdxMapMutex;
 static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;
 
+// Until send/recv is implemented, xcclActiveGroupCounter_ has no effect.
+thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;
+
 ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     at::Device& device,
     int rank,
@@ -240,6 +287,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   return it->second;
 }
 
+void ProcessGroupXCCL::groupStart() {
+  ccl::group_start();
+  ++xcclActiveGroupCounter_;
+}
+
+void ProcessGroupXCCL::groupEnd() {
+  ccl::group_end();
+  --xcclActiveGroupCounter_;
+}
+
 template <typename Fn>
 c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     at::Tensor& input,
@@ -302,6 +359,53 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
       opType);
 }
 
+template <typename Fn>
+c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
+    std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs,
+    Fn fn,
+    OpType opType) {
+
+  using traits = function_traits<Fn>;
+  using attr_t = typename traits::template arg<2>::type;
+  attr_t attr = ccl::create_operation_attr<attr_t>();
+
+  auto device = inputs[0].device();
+  const auto key = std::to_string(device.index());
+  auto comm = getXCCLComm(key, device);
+
+  auto stream = xcclStreams_.at(key);
+  auto ccl_stream = ccl::create_stream(stream.queue());
+
+  c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
+
+  work = initWork(device, rank_, opType);
+
+  work->outputs_ =
+      std::make_shared<std::vector<at::Tensor>>(outputs);
+
+  {
+    AutoXcclGroup xccl_group_guard(comm);
+    for (const auto i : c10::irange(inputs.size())) {
+      c10::xpu::XPUCachingAllocator::recordStream(
+          inputs[i].storage().data_ptr(), stream);
+      fn(inputs[i], outputs[i], attr, *comm, ccl_stream);
+    }
+  }
+
+  work->xcclEndEvent_->record(stream);
+
+  std::vector<c10::Stream> streams = {stream.unwrap()};
+  c10::MultiStreamGuard streamGuard(streams);
+  std::vector<at::Device> devices{device};
+  work->future_ = c10::make_intrusive<at::ivalue::Future>(
+      c10::ListType::create(c10::TensorType::get()), devices);
+  work->future_->markCompleted(at::IValue(*work->outputs_));
+
+  return work;
+
+}
+
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
@@ -334,6 +438,39 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       OpType::ALLREDUCE);
 }
 
+c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
+    std::vector<at::Tensor>& tensors,
+    const AllreduceCoalescedOptions& opts) {
+  check_xpu_tensors_same_device(tensors);
+  TORCH_CHECK(
+      !isFloat8Type(tensors.back().scalar_type()),
+      "Float8 dtypes are not currently supported for XCCL reductions");
+
+  return collectiveCoalesced(
+      tensors,
+      tensors,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          ccl::allreduce_attr attr,
+          xcclComm_t& comm,
+          ccl::stream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::event ret_evt;
+        ret_evt = ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            stream,
+            attr);
+        return ret_evt;
+      },
+      OpType::COALESCED);
+}
+
 } // namespace c10d
 
 #endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 2f16df6450fe62..f2f9bb10022745 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -35,6 +35,15 @@ using xcclComm_t = ccl::communicator;
 using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
 
+namespace {
+struct AutoXcclGroup {
+  AutoXcclGroup();
+  AutoXcclGroup(std::shared_ptr<xcclComm_t> comm);
+  ~AutoXcclGroup() noexcept(false);
+  std::shared_ptr<xcclComm_t> comm_;
+};
+} // namespace
+
 class TORCH_API ProcessGroupXCCL : public Backend {
  public:
   class WorkXCCL : public Work {
@@ -126,6 +135,13 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       PostProcess post,
       OpType opType);
 
+  template <typename Fn>
+  c10::intrusive_ptr<Work> collectiveCoalesced(
+      std::vector<at::Tensor>& input,
+      std::vector<at::Tensor>& output,
+      Fn fn,
+      OpType opType);
+
   c10::intrusive_ptr<Work> allreduce(
       std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts = AllreduceOptions()) override;
@@ -133,9 +149,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   c10::intrusive_ptr<Work> allreduce_coalesced(
       std::vector<at::Tensor>& tensors,
       const AllreduceCoalescedOptions& opts =
-          AllreduceCoalescedOptions()) override {
-    TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented");
-  }
+          AllreduceCoalescedOptions()) override;
 
   c10::intrusive_ptr<Work> reduce(
       std::vector<at::Tensor>& tensors,
@@ -238,6 +252,10 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented");
   }
 
+  void groupStart();
+
+  void groupEnd();
+
   c10::intrusive_ptr<Work> gather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
@@ -259,6 +277,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
   c10::intrusive_ptr<Store> store_;
   std::mutex mutex_;
+  static thread_local uint64_t xcclActiveGroupCounter_;
+  uint64_t seqCollective_{0};
 };
 } // namespace c10d
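
Not part of the patch: a minimal standalone sketch of the coalescing pattern that `collectiveCoalesced` relies on. Several oneCCL collectives are submitted between `ccl::group_start()` and `ccl::group_end()` (the same calls the `AutoXcclGroup` guard wraps), so the backend can treat them as one batched operation. The helper name `allreduce_batched` and the `bufs`/`counts` parameters are illustrative only; communicator, stream, and device-buffer setup are assumed to exist elsewhere.

```cpp
#include <oneapi/ccl.hpp>

#include <cstddef>
#include <vector>

// Illustrative helper (not from the patch): in-place sum-allreduce over several
// device buffers, batched into one group the way collectiveCoalesced() batches
// its per-tensor ccl::allreduce calls.
void allreduce_batched(
    const std::vector<float*>& bufs,
    const std::vector<std::size_t>& counts,
    ccl::communicator& comm,
    ccl::stream& stream) {
  ccl::group_start(); // what the AutoXcclGroup constructor does
  for (std::size_t i = 0; i < bufs.size(); ++i) {
    // Each collective is submitted into the open group; the backend is free to
    // fuse or overlap them until the group is closed.
    ccl::allreduce(
        bufs[i], bufs[i], counts[i], ccl::reduction::sum, comm, stream);
  }
  ccl::group_end(); // what the AutoXcclGroup destructor does
  // Completion tracking is omitted here; the patch instead records an end
  // event on the stream (work->xcclEndEvent_) rather than waiting inline.
}
```

From the Python side this backend method would typically be reached through `torch.distributed.all_reduce_coalesced` (or the coalescing manager) once the `xccl` backend is registered; backend registration is not part of this diff.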