diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index a89a7c48a01ff..956e80482af28 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -82,40 +82,38 @@ void check_xpu_single_tensor(
     const at::Tensor& tensor,
     const bool p2p = false // whether operation is a P2P operation
 ) {
-  if (!tensor.is_xpu() || tensor.is_sparse()) {
-    C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
-  }
-  // Skip the following requirements for P2P operations
-  if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
-    if (p2p) {
-      TORCH_WARN_ONCE(
-          "Detected non-contiguous tensor in P2P operations. It is user "
-          "responsibility to guarantee that source and destination tensors have "
-          "the same contiguity format.");
-    } else {
-      C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
+    C10_THROW_ERROR(
+        ValueError, "Tensors must be XPU and dense and non-complex");
+  }
+
+  // Skip the following requirements for P2P operations
+  if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
+    if (p2p) {
+      TORCH_WARN_ONCE(
+          "Detected non-contiguous tensor in P2P operations. It is user "
+          "responsibility to guarantee that source and destination tensors have "
+          "the same contiguity format.");
+    } else {
+      C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
     }
   }
 }
 
 int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
-  if (tensors.size() == 0) {
-    C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
-  }
+  TORCH_CHECK_WITH(
+      ValueError, tensors.size() != 0, "Tensor list must be nonempty");
 
   const auto& first = tensors.front();
 
   int64_t total_numel = 0;
   for (const auto& t : tensors) {
-    if (!t.is_xpu() || t.is_sparse()) {
-      C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
+    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
+      C10_THROW_ERROR(
+          ValueError, "Tensors must be XPU and dense and non-complex");
     }
     if (t.scalar_type() != first.scalar_type()) {
       C10_THROW_ERROR(TypeError, "Tensors must have identical type");
     }
-    if (!t.is_non_overlapping_and_dense()) {
-      C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
-    }
     TORCH_CHECK_WITH(
         ValueError,
         t.get_device() == tensors[0].get_device(),
@@ -126,7 +124,12 @@ int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
   return total_numel;
 }
 
-ccl::datatype getXcclDataType(at::ScalarType type) {
+ccl::datatype getXcclDataType(
+    at::ScalarType type,
+    bool is_reduction_op = false) {
+  TORCH_CHECK(
+      !(isFloat8Type(type) && is_reduction_op),
+      "Float8 dtypes are not currently supported for XCCL reductions");
   auto it = xcclDatatypes.find(type);
   TORCH_CHECK_WITH(
       TypeError,
@@ -158,18 +161,6 @@ void syncStream(
   xcclEvent.block(xcclStream);
 }
 
-bool complexViewAsRealAllowed(const ReduceOp reduceOp) {
-  switch (reduceOp) {
-    case ReduceOp::SUM:
-      return true;
-    case ReduceOp::UNUSED:
-      return true;
-    default:
-      return false;
-  }
-  return false;
-}
-
 } // namespace
 
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
@@ -286,7 +277,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   usedDeviceIdxs_.insert(device.index());
 
   {
-    // todo: why do we need mutex here?
     std::lock_guard<std::mutex> lock(mutex_);
     if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
       return devXCCLCommMap_[deviceKey];
     }
   }
@@ -301,11 +291,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
 
   at::xpu::OptionalXPUGuard gpuGuard(device);
 
-  for (const auto i : c10::irange(xcclActiveGroupCounter_)) {
-    (void)i;
-    ccl::group_end();
-  }
-
   int numRanks, rank;
   if (!singleP2POp) {
     numRanks = getSize();
@@ -329,37 +314,14 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   auto xccl_kvs = get_kvs(rank_, *store_);
   auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
 
-  std::shared_ptr<xcclComm_t> XCCLComm =
-      std::make_shared<xcclComm_t>(std::move(comms[0]));
+  XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
 
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
   xcclStreamsMap_.emplace(deviceKey, std::move(stream));
   xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
 
-  for (const auto i : c10::irange(xcclActiveGroupCounter_)) {
-    (void)i;
-    ccl::group_start();
-  }
-
-  xcclStreams_.emplace(deviceKey, std::move(stream));
-  xcclEvents_.emplace(deviceKey, at::xpu::XPUEvent());
-
-  auto it = inInitializationCommMap_.find(deviceKey);
-  if (it != inInitializationCommMap_.end()) {
-    devXCCLCommMap_.emplace(deviceKey, std::move(it->second));
-    inInitializationCommMap_.erase(deviceKey);
-
-    xcclCommDevIdxMapMutex.lock();
-    xcclCommDevIdxMap.emplace(XCCLComm, device.index());
-    xcclCommDevIdxMapMutex.unlock();
-  }
-
-  it = devXCCLCommMap_.find(deviceKey);
-  TORCH_INTERNAL_ASSERT(
-      it != devXCCLCommMap_.end(), "Communicators not populated in cache!");
-
-  return it->second;
+  return XCCLComm;
 }
 
 void ProcessGroupXCCL::groupStart() {
@@ -396,7 +358,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
 
   auto device = coalescedDevice_;
   const auto key = std::to_string(device.index());
-  auto stream = xcclStreams_.at(key);
+  auto stream = xcclStreamsMap_.at(key);
 
   auto work = initWork(device, rank_, optype);
   work->blockingWait_ = blockingWait_;
@@ -422,7 +384,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     Fn fn,
     PreProcess pre,
     PostProcess post,
-    OpType opType) {
+    OpType opType,
+    const char* profilingTitle) {
   auto device = inputs[0].device();
   const auto key = std::to_string(device.index());
   auto comm = getXCCLComm(key, device, opType);
@@ -442,8 +405,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     }
   }
 
-  auto stream = xcclStreams_.at(key);
-  syncStream(device, xcclEvents_[key], stream);
+  auto stream = xcclStreamsMap_.at(key);
+  syncStream(device, xcclEventsMap_[key], stream);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
   work = initWork(device, rank_, opType);
@@ -454,13 +417,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
 
   pre(stream, work);
 
-  for (const auto& input : inputs) {
+  for (const auto i : c10::irange(inputs.size())) {
     c10::xpu::XPUCachingAllocator::recordStream(
-        input.storage().data_ptr(), stream);
+        inputs[i].storage().data_ptr(), stream);
+    fn(inputs[i], outputs[i], *comm, stream);
   }
 
-  fn(inputs[0], outputs[0], *comm, stream);
-
   post(stream, work);
 
   if (!coalescing_state_) {
@@ -478,97 +440,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   return work;
 }
 
-template <typename Fn, typename PreProcess, typename PostProcess>
-c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
-    at::Tensor& input,
-    at::Tensor& output,
-    Fn fn,
-    PreProcess pre,
-    PostProcess post,
-    OpType opType) {
-  auto inputs = std::vector<at::Tensor>{input};
-  auto outputs = std::vector<at::Tensor>{output};
-  return collective(inputs, outputs, fn, pre, post, opType);
-}
-
-template <typename Fn>
-c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
-    at::Tensor& input,
-    at::Tensor& output,
-    Fn fn,
-    OpType opType) {
-  return collective(
-      input,
-      output,
-      fn,
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      opType);
-}
-
-template <typename Fn>
-c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
-    std::vector<at::Tensor>& inputs,
-    std::vector<at::Tensor>& outputs,
-    Fn fn,
-    OpType opType) {
-  auto device = inputs[0].device();
-  const auto key = std::to_string(device.index());
-  auto comm = getXCCLComm(key, device, opType);
-
-  if (coalescing_state_ & CoalActive) {
-    coalescing_state_ |= CoalColl;
-    if (coalescedDevice_.index() < 0) {
-      coalescedDevice_ = device;
-    } else {
-      TORCH_CHECK(
-          coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
-    }
-    if (coalescedComm_ == nullptr) {
-      coalescedComm_ = comm;
-    } else {
-      TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
-    }
-  }
-
-  auto stream = xcclStreamsMap_.at(key);
-  syncStream(device, xcclEventsMap_[key], stream);
-
-  c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
-  work = initWork(device, rank_, opType, profilingTitle);
-  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
-
-  at::xpu::OptionalXPUGuard gpuGuard(device);
-  pre(stream, work);
-  for (const auto i : c10::irange(inputs.size())) {
-    c10::xpu::XPUCachingAllocator::recordStream(
-        inputs[i].storage().data_ptr(), stream);
-    fn(inputs[i], outputs[i], *comm, stream);
-  }
-  post(stream, work);
-
-  work->xcclEndEvent_->record(stream);
-  std::vector<c10::Stream> streams = {stream.unwrap()};
-  c10::MultiStreamGuard streamGuard(streams);
-  std::vector<at::Device> devices{device};
-  work->future_ = c10::make_intrusive<at::ivalue::Future>(
-      c10::ListType::create(c10::TensorType::get()), devices);
-  work->future_->markCompleted(at::IValue(*work->outputs_));
-  work->blockingWait_ = blockingWait_;
-
-  return work;
-}
-
-template <typename Fn, typename PreProcess, typename PostProcess>
+template <typename Fn>
 c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     at::Tensor& tensor,
     Fn fn,
     int peer,
     OpType opType,
-    PreProcess pre,
-    PostProcess post) {
+    const char* profilingTitle) {
   auto device = tensor.device();
   std::string key;
   int p2pRank = 0, p2pTargetRank = 0;
@@ -605,63 +483,43 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
     }
   }
 
-  auto stream = xcclStreams_.at(key);
-  syncStream(device, xcclEvents_[key], stream);
+  auto stream = xcclStreamsMap_.at(key);
+  syncStream(device, xcclEventsMap_[key], stream);
 
-  c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
   if (!coalescing_state_) {
+    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
     work = initWork(device, rank_, opType);
     work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
     work->outputs_->push_back(tensor);
-  }
-
-  at::xpu::OptionalXPUGuard gpuGuard(device);
-
-  if (!coalescing_state_) {
-    pre(stream, work);
-  }
 
-  c10::xpu::XPUCachingAllocator::recordStream(
-      tensor.storage().data_ptr(), stream);
+    at::xpu::OptionalXPUGuard gpuGuard(device);
 
-  fn(tensor, *comm, stream, p2pTargetRank);
+    c10::xpu::XPUCachingAllocator::recordStream(
+        tensor.storage().data_ptr(), stream);
 
-  if (!coalescing_state_) {
-    post(stream);
+    fn(tensor, *comm, stream, p2pTargetRank);
 
     work->xcclEndEvent_->record(stream);
     work->blockingWait_ = blockingWait_;
-
-    {
-      std::vector<c10::Stream> streams = {stream.unwrap()};
-      c10::MultiStreamGuard streamGuard(streams);
-      std::vector<at::Device> devices{device};
-      work->future_ = c10::make_intrusive<at::ivalue::Future>(
-          c10::ListType::create(c10::TensorType::get()), devices);
-      work->future_->markCompleted(at::IValue(*work->outputs_));
-    }
+    std::vector<c10::Stream> streams = {stream.unwrap()};
+    c10::MultiStreamGuard streamGuard(streams);
+    std::vector<at::Device> devices{device};
+    work->future_ = c10::make_intrusive<at::ivalue::Future>(
+        c10::ListType::create(c10::TensorType::get()), devices);
+    work->future_->markCompleted(at::IValue(*work->outputs_));
 
     return work;
   } else {
+    at::xpu::OptionalXPUGuard gpuGuard(device);
+
+    c10::xpu::XPUCachingAllocator::recordStream(
+        tensor.storage().data_ptr(), stream);
+
+    fn(tensor, *comm, stream, p2pTargetRank);
+
     return nullptr;
   }
 }
 
-template <typename Fn>
-c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
-    at::Tensor& tensor,
-    Fn fn,
-    int peer,
-    OpType opType) {
-  return pointToPoint(
-      tensor,
-      fn,
-      peer,
-      opType,
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      [](at::xpu::XPUStream&) {});
-}
-
 c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
     std::vector<at::Tensor>& tensors,
     int dstRank,
@@ -677,19 +535,19 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
           xcclComm_t& comm,
           at::xpu::XPUStream& stream,
           int dst) {
-        ccl::event ret_evt;
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        ret_evt = ccl::send(
+        ccl::send(
             input.data_ptr(),
             (size_t)input.numel(),
             xcclDataType,
             dst,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
       dstRank,
-      OpType::SEND);
+      OpType::SEND,
+      c10::str("xccl:send ", rank_, "->", dstRank).c_str());
   return ret;
 }
 
@@ -708,19 +566,19 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
           xcclComm_t& comm,
           at::xpu::XPUStream& stream,
           int src) {
-        ccl::event ret_evt;
         auto xcclDataType = getXcclDataType(output.scalar_type());
-        ret_evt = ccl::recv(
+        ccl::recv(
            output.data_ptr(),
            (size_t)output.numel(),
            xcclDataType,
            src,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
       srcRank,
-      OpType::RECV);
+      OpType::RECV,
+      c10::str("xccl:recv ", rank_, "<-", srcRank).c_str());
  return ret;
 }
 
@@ -785,13 +643,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
           }
         }
         {
-          ccl::event ret_evt;
           auto xcclDataType = getXcclDataType(inputTensor.scalar_type());
           if (rank_ == root) {
             for (const auto r : c10::irange(size_)) {
               if (r != root) {
                 // do receive
-                ret_evt = ccl::recv(
+                ccl::recv(
                     outputs[r].data_ptr(),
                     (size_t)inputTensor.numel(),
                     xcclDataType,
@@ -805,7 +662,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
             }
           } else {
             // do send
-            ret_evt = ccl::send(
+            ccl::send(
                 inputTensor.data_ptr(),
                 (size_t)inputTensor.numel(),
                 xcclDataType,
@@ -813,13 +670,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
                comm,
                ccl::create_stream(stream.queue()));
           }
-          return ret_evt;
+          return;
         }
       },
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
       OpType::GATHER);
 }
 
@@ -885,14 +738,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
           }
         }
         {
-          ccl::event ret_evt;
           if (rank_ == root) {
             for (const auto r : c10::irange(size_)) {
               if (r != root) {
                 // do send
                 size_t send_count = inputs[r].numel();
                 auto send_type = getXcclDataType(inputs[r].scalar_type());
-                ret_evt = ccl::send(
+                ccl::send(
                     inputs[r].data_ptr(),
                     send_count,
                     send_type,
@@ -908,7 +760,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
             // do receive
             size_t recv_count = outputTensor.numel();
             auto recv_type = getXcclDataType(outputTensor.scalar_type());
-            ret_evt = ccl::recv(
+            ccl::recv(
                 outputTensor.data_ptr(),
                 recv_count,
                 recv_type,
@@ -917,13 +769,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
                ccl::create_stream(stream.queue()));
           }
 
-          return ret_evt;
+          return;
         }
       },
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
       OpType::SCATTER);
 }
 
@@ -937,7 +785,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
           at::Tensor& output,
           xcclComm_t& comm,
           at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        auto ccl_stream = ccl::create_stream(stream.queue());
        ccl::allreduce(
@@ -948,7 +796,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
            xcclReduceOp,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
       OpType::ALLREDUCE,
       "xccl:all_reduce");
@@ -959,30 +807,35 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     const AllreduceOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    TORCH_CHECK(
-        complexViewAsRealAllowed(opts.reduceOp),
-        "all_reduce does not support",
-        opts.reduceOp,
-        "on complex tensors");
-    tensor = at::view_as_real(tensor);
-  }
   check_xpu_single_tensor(tensor);
-  TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
-  return allreduce_impl(tensor, opts);
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl::create_stream(stream.queue()));
+        return;
+      },
+      OpType::ALLREDUCE,
+      "xccl:all_reduce");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
     std::vector<at::Tensor>& tensors,
     const AllreduceCoalescedOptions& opts) {
   check_xpu_tensors_same_device(tensors);
-  TORCH_CHECK(
-      !isFloat8Type(tensors.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
   return collectiveCoalesced(
       tensors,
       tensors,
@@ -990,10 +843,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
           at::Tensor& output,
           xcclComm_t& comm,
           at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::allreduce(
+        ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
@@ -1001,9 +853,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
            xcclReduceOp,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::COALESCED);
+      OpType::COALESCED,
+      "xccl:allreduce_coalesced");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
@@ -1011,9 +864,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
     const BroadcastOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    tensor = at::view_as_real(tensor);
-  }
   check_xpu_single_tensor(tensor);
 
   const auto root = opts.rootRank + opts.rootTensor;
@@ -1026,17 +876,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
           xcclComm_t& comm,
           at::xpu::XPUStream& stream) {
        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::event ret_evt;
-        ret_evt = ccl::broadcast(
+        ccl::broadcast(
            input.data_ptr(),
            (size_t)input.numel(),
            xcclDataType,
            root,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::BROADCAST);
+      OpType::BROADCAST,
+      "xccl:broadcast");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
@@ -1057,33 +907,24 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
           xcclComm_t& comm,
           at::xpu::XPUStream& stream) {
        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::event ret_evt;
-        ret_evt = ccl::broadcast(
+        ccl::broadcast(
            input.data_ptr(),
            (size_t)input.numel(),
            xcclDataType,
            root,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::BROADCAST);
+      OpType::BROADCAST,
+      "xccl:_broadcast_oop");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
     std::vector<at::Tensor>& tensors,
     const ReduceOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    TORCH_CHECK(
-        complexViewAsRealAllowed(opts.reduceOp),
-        "reduce does not support",
-        opts.reduceOp,
-        "on complex tensors");
-    tensor = at::view_as_real(tensor);
-  }
   check_xpu_single_tensor(tensor);
 
   return collective(
@@ -1094,10 +935,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
           xcclComm_t& comm,
           at::xpu::XPUStream& stream) {
         const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type());
+        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::reduce(
+        ccl::reduce(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
@@ -1106,20 +946,20 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
            root,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::REDUCE);
+      OpType::REDUCE,
+      "xccl:reduce");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
     at::Tensor& outputTensor,
     at::Tensor& inputTensor,
     const ReduceOptions& opts) {
-  if (outputTensor.numel() != inputTensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "Tensor input and output of _reduce_oop must have the same number of elements ");
-  }
+  TORCH_CHECK_WITH(
+      ValueError,
+      outputTensor.numel() == inputTensor.numel(),
+      "Tensor input and output of _reduce_oop must have the same number of elements");
   return collective(
       inputTensor,
       outputTensor,
@@ -1128,10 +968,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
          xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
        const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type());
+        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::reduce(
+        ccl::reduce(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
@@ -1140,9 +979,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
            root,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::REDUCE);
+      OpType::REDUCE,
+      "xccl:_reduce_oop");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
@@ -1171,16 +1011,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
          c10::xpu::XPUCachingAllocator::recordStream(
              output.storage().data_ptr(), stream);
          auto xcclDataType = getXcclDataType(input.scalar_type());
-          ccl::event ret_evt;
-
-          ret_evt = ccl::allgather(
+          ccl::allgather(
              input.data_ptr(),
              output.data_ptr(),
              (size_t)input.numel(),
              xcclDataType,
              comm,
              ccl::create_stream(stream.queue()));
-          return ret_evt;
+          return;
        },
        [](at::xpu::XPUStream&,
           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
@@ -1194,7 +1032,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
            outputTensors_[j].copy_(outputFlattened[j], true);
          }
        },
-        OpType::ALLGATHER);
+        OpType::ALLGATHER,
+        "xccl:all_gather");
  } else {
    const auto num_reduces = outputTensors_.size();
    startCoalescing();
@@ -1217,16 +1056,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
   check_xpu_single_tensor(input_tensor);
   check_xpu_single_tensor(output_tensor);
 
-  if (input_tensor.dtype() != output_tensor.dtype()) {
-    C10_THROW_ERROR(
-        TypeError, "output tensor must have the same type as input tensor");
-  }
-
-  if (input_tensor.numel() * size_ != output_tensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "output tensor size must be equal to world_size times input tensor size");
-  }
+  TORCH_CHECK_WITH(
+      TypeError,
+      input_tensor.dtype() == output_tensor.dtype(),
+      "output tensor must have the same type as input tensor");
+  TORCH_CHECK_WITH(
+      ValueError,
+      input_tensor.numel() * size_ == output_tensor.numel(),
+      "output tensor size must be equal to world_size times input tensor size");
 
   return collective(
       input_tensor,
@@ -1238,17 +1075,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
         c10::xpu::XPUCachingAllocator::recordStream(
             output.storage().data_ptr(), stream);
         auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::event ret_evt;
-        ret_evt = ccl::allgather(
+        ccl::allgather(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
            xcclDataType,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::_ALLGATHER_BASE);
+      OpType::_ALLGATHER_BASE,
+      "xccl:_all_gather_base");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
@@ -1263,17 +1100,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
          xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::event ret_evt;
-        ret_evt = ccl::allgather(
+        ccl::allgather(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
            xcclDataType,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::COALESCED);
+      OpType::COALESCED,
+      "xccl:all_gather_into_tensor_coalesced");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
@@ -1286,9 +1123,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
   check_xpu_single_tensor(outputTensor);
   // @lint-ignore CLANGTIDY
   auto inputTensors_ = inputTensors.back();
-  TORCH_CHECK(
-      !isFloat8Type(outputTensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
   bool same_size = check_same_size(inputTensors_);
   if (same_size) {
@@ -1303,10 +1137,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
            at::xpu::XPUStream& stream) {
          c10::xpu::XPUCachingAllocator::recordStream(
              output.storage().data_ptr(), stream);
-          auto xcclDataType = getXcclDataType(input.scalar_type());
+          auto xcclDataType = getXcclDataType(input.scalar_type(), true);
          auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-          ccl::event ret_evt;
-          ret_evt = ccl::reduce_scatter(
+          ccl::reduce_scatter(
              input.data_ptr(),
              output.data_ptr(),
              (size_t)output.numel(),
@@ -1314,7 +1147,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
              xcclReduceOp,
              comm,
              ccl::create_stream(stream.queue()));
-          return ret_evt;
+          return;
        },
        [&](at::xpu::XPUStream& Stream,
            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {
@@ -1328,7 +1161,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
        },
        [&](at::xpu::XPUStream&,
            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        OpType::REDUCE_SCATTER);
+        OpType::REDUCE_SCATTER,
+        "xccl:reduce_scatter");
  } else {
    const auto num_reduces = inputTensors_.size();
    startCoalescing();
@@ -1351,22 +1185,14 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
     at::Tensor& outputTensor,
     at::Tensor& inputTensor,
     const ReduceScatterOptions& opts) {
-  if (inputTensor.dtype() != outputTensor.dtype()) {
-    C10_THROW_ERROR(
-        TypeError, "input tensor must be the same type as the output tensor.");
-  }
-
-  if (inputTensor.numel() != outputTensor.numel() * size_) {
-    C10_THROW_ERROR(
-        ValueError,
-        "input tensor must be the same size as output size times world size");
-  }
-
-  // @lint-ignore CLANGTIDY
-  const auto& tensor = outputTensor;
-  TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
+  TORCH_CHECK_WITH(
+      TypeError,
+      inputTensor.dtype() == outputTensor.dtype(),
+      "output tensor must have the same type as input tensor");
+  TORCH_CHECK_WITH(
+      ValueError,
+      inputTensor.numel() == outputTensor.numel() * size_,
+      "input tensor size must be equal to world_size times output tensor size");
 
   return collective(
       inputTensor,
@@ -1377,10 +1203,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
          at::xpu::XPUStream& stream) {
        c10::xpu::XPUCachingAllocator::recordStream(
            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::reduce_scatter(
+        ccl::reduce_scatter(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)output.numel(),
@@ -1388,18 +1213,16 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
            xcclReduceOp,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::_REDUCE_SCATTER_BASE);
+      OpType::_REDUCE_SCATTER_BASE,
+      "xccl:_reduce_scatter_base");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& outputs,
     std::vector<at::Tensor>& inputs,
     const ReduceScatterOptions& opts) {
-  TORCH_CHECK(
-      !isFloat8Type(inputs.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
   return collectiveCoalesced(
       inputs,
       outputs,
@@ -1409,10 +1232,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
          at::xpu::XPUStream& stream) {
        c10::xpu::XPUCachingAllocator::recordStream(
            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::reduce_scatter(
+        ccl::reduce_scatter(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)output.numel(),
@@ -1420,9 +1242,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
            xcclReduceOp,
            comm,
            ccl::create_stream(stream.queue()));
-        return ret_evt;
+        return;
       },
-      OpType::COALESCED);
+      OpType::COALESCED,
+      "xccl:reduce_scatter_tensor_coalesced");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
@@ -1441,6 +1264,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::barrier(const BarrierOptions& opts) {
         static_cast<int16_t>(rank_ % at::detail::getXPUHooks().getNumGPUs());
   }
 
+  // todo: use barrier instead of allreduce
   TORCH_CHECK_WITH(
       ValueError,
       barDevIdx >= 0,
@@ -1484,17 +1308,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
          c10::xpu::XPUCachingAllocator::recordStream(
              output.storage().data_ptr(), stream);
          auto xcclDataType = getXcclDataType(output.scalar_type());
-          ccl::event ret_evt;
-          ret_evt = ccl::alltoall(
+          ccl::alltoall(
              input.data_ptr(),
              output.data_ptr(),
              (size_t)output.numel() / comm.size(),
              xcclDataType,
              comm,
              ccl::create_stream(stream.queue()));
-          return ret_evt;
+          return;
        },
-        OpType::ALLTOALL_BASE);
+        OpType::ALLTOALL_BASE,
+        "xccl:all_to_all");
  } else {
    c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
    c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
@@ -1525,9 +1349,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
                (outputSplitsEqual ? outLen : outputSplitSizes[i] * outLen);
          }
          auto xcclDataType = getXcclDataType(output.scalar_type());
-          ccl::event ret_evt;
-
-          ret_evt = ccl::alltoallv(
+          ccl::alltoallv(
              input.data_ptr(),
              sendCounts,
              output.data_ptr(),
@@ -1535,9 +1357,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall_base(
              xcclDataType,
              comm,
              ccl::create_stream(stream.queue()));
-          return ret_evt;
+          return;
        },
-        OpType::ALLTOALL_BASE);
+        OpType::ALLTOALL_BASE,
+        "xccl:all_to_all");
  }
 }
 
@@ -1607,15 +1430,11 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
            outputTensors[i].view({-1}).copy_(flatOutputSplits[i]);
          }
        }
-        stream.synchronize();
-        return ret_evt;
-      },
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-      [](at::xpu::XPUStream&,
-         c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        return;
       },
-      OpType::ALLTOALL);
+      OpType::ALLTOALL,
+      "xccl:all_to_all");
 }
 
 } // namespace c10d
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 0f2b2738a4b77..c8fa11442c692 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -21,18 +21,12 @@
 #include
 
 namespace c10d {
-namespace {
-struct AutoXcclGroup {
-  AutoXcclGroup();
-  ~AutoXcclGroup() noexcept(false);
-};
-} // namespace
-
 static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
     "TORCH_XCCL_BLOCKING_WAIT",
     "XCCL_BLOCKING_WAIT"};
 
 using xcclComm_t = ccl::communicator;
+using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
 
 class TORCH_API ProcessGroupXCCL : public Backend {
@@ -129,28 +123,50 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       Fn fn,
       OpType opType,
       const char* profilingTitle = nullptr) {
-    auto inputs = std::vector<at::Tensor>{input};
-    auto outputs = std::vector<at::Tensor>{output};
     return collective(
-        inputs,
-        outputs,
+        input,
+        output,
         fn,
        [](at::xpu::XPUStream&,
           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
        [](at::xpu::XPUStream&,
           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        opType);
+        opType,
+        profilingTitle);
  }
 
  template <typename Fn, typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<Work> collective(
-      std::vector<at::Tensor>& inputs,
-      std::vector<at::Tensor>& outputs,
+      at::Tensor& input,
+      at::Tensor& output,
      Fn fn,
      PreProcess pre,
      PostProcess post,
      OpType opType,
-      const char* profilingTitle = nullptr);
+      const char* profilingTitle = nullptr) {
+    auto inputs = std::vector<at::Tensor>{input};
+    auto outputs = std::vector<at::Tensor>{output};
+    return collective(inputs, outputs, fn, pre, post, opType, profilingTitle);
+  }
+
+  template <typename Fn>
+  c10::intrusive_ptr<Work> collective(
+      std::vector<at::Tensor>& inputs,
+      std::vector<at::Tensor>& outputs,
+      Fn fn,
+      OpType opType,
+      const char* profilingTitle = nullptr) {
+    return collective(
+        inputs,
+        outputs,
+        fn,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        opType,
+        profilingTitle);
+  }
 
  template <typename Fn, typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<Work> collective(
@@ -159,30 +175,39 @@ class TORCH_API ProcessGroupXCCL : public Backend {
      Fn fn,
      PreProcess pre,
      PostProcess post,
-      OpType opType);
+      OpType opType,
+      const char* profilingTitle = nullptr);
 
  template <typename Fn>
  c10::intrusive_ptr<Work> collectiveCoalesced(
      std::vector<at::Tensor>& input,
      std::vector<at::Tensor>& output,
      Fn fn,
-      OpType opType);
+      OpType opType,
+      const char* profilingTitle = nullptr) {
+    return collective(
+        input,
+        output,
+        fn,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {
+          ccl::group_start();
+        },
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {
+          ccl::group_end();
        },
+        opType,
+        profilingTitle);
+  }
 
  template <typename Fn>
-  c10::intrusive_ptr<Work> pointToPoint(
-      at::Tensor& tensor,
-      Fn fn,
-      int peer,
-      OpType opType);
-
-  template <typename Fn, typename PreProcess, typename PostProcess>
  c10::intrusive_ptr<Work> pointToPoint(
      at::Tensor& tensor,
      Fn fn,
      int peer,
      OpType opType,
-      PreProcess pre,
-      PostProcess post);
+      const char* profilingTitle = nullptr);
 
  c10::intrusive_ptr<Work> allreduce_impl(
      at::Tensor& tensor,
@@ -285,10 +310,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
      const ScatterOptions& opts = ScatterOptions()) override;
 
 protected:
-  std::unordered_map<std::string, at::xpu::XPUStream> xcclStreams_;
-  std::unordered_map<std::string, at::xpu::XPUEvent> xcclEvents_;
-  std::unordered_map<std::string, std::shared_ptr<xcclComm_t>>
-      inInitializationCommMap_;
+  std::unordered_map<std::string, at::xpu::XPUStream> xcclStreamsMap_;
+  std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
  std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
  c10::intrusive_ptr<Store> store_;
  std::mutex mutex_;
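
A note on the check macros that the rewritten validation code relies on: `TORCH_CHECK(cond, ...)` and `TORCH_CHECK_WITH(Error, cond, ...)` throw when `cond` evaluates to false, so the condition must describe the valid state (e.g. a non-empty tensor list, matching dtypes, matching element counts). A minimal standalone sketch, illustrative only and not part of the patch (the function name is hypothetical):

// check_macro_sketch.cpp -- hedged illustration of c10 check semantics.
#include <c10/util/Exception.h>
#include <vector>

void check_nonempty(const std::vector<int>& xs) {
  // Throws c10::ValueError only when xs is empty: the condition states the
  // "good" case, mirroring the size/dtype/numel checks in this diff.
  TORCH_CHECK_WITH(ValueError, !xs.empty(), "Tensor list must be nonempty");
}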
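Similarly, the float8 guard added to `getXcclDataType` is only meant to fire for reduction collectives; non-reduction operations should pass float8 inputs through to the dtype lookup. A hedged sketch of that intent using the existing `c10::isFloat8Type` helper (the standalone function below is hypothetical and only mirrors the guard, not the full dtype mapping):

// float8_guard_sketch.cpp -- reductions reject float8, other ops do not.
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>

void check_reduction_dtype(c10::ScalarType type, bool is_reduction_op) {
  // Fails only when both conditions hold: the dtype is a float8 variant and
  // the caller is performing a reduction (allreduce, reduce, reduce_scatter).
  TORCH_CHECK(
      !(c10::isFloat8Type(type) && is_reduction_op),
      "Float8 dtypes are not currently supported for XCCL reductions");
}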