diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 3a795c817bec22..8867904d92e24c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -19,15 +19,6 @@ namespace c10d { namespace { -// wait nonblocking implement -AutoXcclGroup::AutoXcclGroup() { - ccl::group_start(); -} - -AutoXcclGroup::~AutoXcclGroup() noexcept(false) { - ccl::group_end(); -} - std::map xcclOps = { {ReduceOp::MIN, ccl::reduction::min}, {ReduceOp::MAX, ccl::reduction::max}, @@ -99,44 +90,32 @@ bool check_same_size(const std::vector& input_tensors) { return true; } -void check_xpu_single_tensor( - const at::Tensor& tensor, - const bool p2p = false // whether operation is a P2P operation -) { - if (!tensor.is_xpu() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); +void check_xpu_single_tensor(const at::Tensor& tensor) { + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense and non-complex"); } - // Skip the following requirements for P2P operations if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - if (p2p) { - TORCH_WARN_ONCE( - "Detected non-contiguous tensor in P2P operations. It is user " - "responsibility to guarantee that source and destination tensors have " - "the same contiguity format."); - } else { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); - } + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); } } int64_t check_xpu_tensors_same_device(const std::vector& tensors) { - if (tensors.size() == 0) { - C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); - } + TORCH_CHECK_WITH(ValueError, tensors.size() == 0, "Tensor list must be nonempty"); const auto& first = tensors.front(); int64_t total_numel = 0; for (const auto& t : tensors) { - if (!t.is_xpu() || t.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); + if (!t.is_xpu() || t.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense and non-complex"); } if (t.scalar_type() != first.scalar_type()) { C10_THROW_ERROR(TypeError, "Tensors must have identical type"); } - if (!t.is_non_overlapping_and_dense()) { - C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); - } + // todo: zl_debug do we need it? 
+// if (!t.is_non_overlapping_and_dense()) { +// C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); +// } TORCH_CHECK_WITH( ValueError, t.get_device() == tensors[0].get_device(), @@ -147,7 +126,11 @@ int64_t check_xpu_tensors_same_device(const std::vector& tensors) { return total_numel; } -ccl::datatype getXcclDataType(at::ScalarType type) { +ccl::datatype getXcclDataType(at::ScalarType type, bool is_reduction_op = false) { + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()) && is_reduction_op, + "Float8 dtypes are not currenlty supported for XCCL reductions"); + auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( TypeError, @@ -183,18 +166,6 @@ void syncStream( xcclEvent.block(xcclStream); } -bool complexViewAsRealAllowed(const ReduceOp reduceOp) { - switch (reduceOp) { - case ReduceOp::SUM: - return true; - case ReduceOp::UNUSED: - return true; - default: - return false; - } - return false; -} - } // namespace static std::mutex xcclCommDevIdxMapMutex; @@ -230,6 +201,11 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() { return false; } +void ProcessGroupXCCL::WorkXCCL::abort() { + std::lock_guard lock(mutex_); + devXCCLCommMap_.erase(xcclComm_); +} + void ProcessGroupXCCL::WorkXCCL::synchronize() { synchronizeInternal(kNoTimeout); } @@ -248,9 +224,9 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( "Work ran for ", timeElapsed.count(), " milliseconds before timing out."); - TORCH_CHECK(false, exceptionMsg) + abort(); + TORCH_CHECK(false, exceptionMsg); } - std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } @@ -335,10 +311,11 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( at::xpu::OptionalXPUGuard gpuGuard(device); - for (const auto i : c10::irange(xcclActiveGroupCounter_)) { - (void)i; - ccl::group_end(); - } +// todo: I think we don't need it now, with Coalesced manager +// for (const auto i : c10::irange(xcclActiveGroupCounter_)) { +// (void)i; +// ccl::group_end(); +// } int numRanks, rank; if (!singleP2POp) { @@ -365,32 +342,18 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( { std::lock_guard lock(mutex_); - inInitializationCommMap_.emplace(deviceKey, XCCLComm); - } - - for (const auto i : c10::irange(xcclActiveGroupCounter_)) { - (void)i; - ccl::group_start(); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); } - xcclStreams_.emplace(deviceKey, std::move(stream)); - xcclEvents_.emplace(deviceKey, at::xpu::XPUEvent()); - - auto it = inInitializationCommMap_.find(deviceKey); - if (it != inInitializationCommMap_.end()) { - devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); - inInitializationCommMap_.erase(deviceKey); - - xcclCommDevIdxMapMutex.lock(); - xcclCommDevIdxMap.emplace(XCCLComm, device.index()); - xcclCommDevIdxMapMutex.unlock(); - } +// for (const auto i : c10::irange(xcclActiveGroupCounter_)) { +// (void)i; +// ccl::group_start(); +// } - it = devXCCLCommMap_.find(deviceKey); - TORCH_INTERNAL_ASSERT( - it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); - return it->second; + return XCCLComm; } void ProcessGroupXCCL::groupStart() { @@ -403,7 +366,6 @@ void ProcessGroupXCCL::groupEnd() { --xcclActiveGroupCounter_; } -// TODO: wait p2p enable static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; void ProcessGroupXCCL::startCoalescing() { coalescedDevice_.set_index(-1); @@ -492,10 +454,9 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( for 
(const auto& input : inputs) { c10::xpu::XPUCachingAllocator::recordStream( input.storage().data_ptr(), stream); + fn(inputs[i], outputs[i], attr, *comm, stream); } - fn(inputs[0], outputs[0], attr, *comm, stream); - post(stream, work); if (!coalescing_state_) { @@ -513,105 +474,12 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType) { - auto inputs = std::vector{input}; - auto outputs = std::vector{output}; - return collective(inputs, outputs, fn, pre, post, opType); -} - -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - OpType opType) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - opType); -} - template -c10::intrusive_ptr ProcessGroupXCCL::collectiveCoalesced( - std::vector& inputs, - std::vector& outputs, - Fn fn, - OpType opType) { - using traits = function_traits; - using attr_t = typename traits::template arg<2>::type; - attr_t attr = ccl::create_operation_attr(); - - auto device = inputs[0].device(); - const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device, opType); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalColl; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = comm; - } else { - TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); - } - } - - auto stream = xcclStreams_.at(key); - syncStream(device, xcclEvents_[key], stream); - - c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - - work->outputs_ = std::make_shared>(outputs); - - at::xpu::OptionalXPUGuard gpuGuard(device); - - { - AutoXcclGroup xccl_group_guard; - for (const auto i : c10::irange(inputs.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - inputs[i].storage().data_ptr(), stream); - fn(inputs[i], outputs[i], attr, *comm, stream); - } - } - - work->xcclEndEvent_->record(stream); - - std::vector streams = {stream.unwrap()}; - c10::MultiStreamGuard streamGuard(streams); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - work->blockingWait_ = blockingWait_; - - return work; -} - -template c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( at::Tensor& tensor, Fn fn, int peer, - OpType opType, - PreProcess pre, - PostProcess post) { + OpType opType) { using traits = function_traits; using attr_t = typename traits::template arg<1>::type; attr_t attr = ccl::create_operation_attr(); @@ -655,60 +523,40 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( auto stream = xcclStreams_.at(key); syncStream(device, xcclEvents_[key], stream); - c10::intrusive_ptr work; if (!coalescing_state_) { + c10::intrusive_ptr work; work = initWork(device, rank_, opType); work->outputs_ = std::make_shared>(); work->outputs_->push_back(tensor); - } - at::xpu::OptionalXPUGuard gpuGuard(device); - - if (!coalescing_state_) { - pre(stream, work); - } + at::xpu::OptionalXPUGuard gpuGuard(device); - c10::xpu::XPUCachingAllocator::recordStream( + c10::xpu::XPUCachingAllocator::recordStream( 
tensor.storage().data_ptr(), stream); - fn(tensor, attr, *comm, stream, p2pTargetRank); - - if (!coalescing_state_) { - post(stream); + fn(tensor, attr, *comm, stream, p2pTargetRank); work->xcclEndEvent_->record(stream); work->blockingWait_ = blockingWait_; - - { - std::vector streams = {stream.unwrap()}; - c10::MultiStreamGuard streamGuard(streams); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - } + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); return work; } else { + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, attr, *comm, stream, p2pTargetRank); + return nullptr; } } -template -c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType) { - return pointToPoint( - tensor, - fn, - peer, - opType, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&) {}); -} - c10::intrusive_ptr ProcessGroupXCCL::send( std::vector& tensors, int dstRank, @@ -725,9 +573,8 @@ c10::intrusive_ptr ProcessGroupXCCL::send( xcclComm_t& comm, at::xpu::XPUStream& stream, int dst) { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(input.scalar_type()); - ret_evt = ccl::send( + ccl::send( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -735,7 +582,7 @@ c10::intrusive_ptr ProcessGroupXCCL::send( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, dstRank, OpType::SEND); @@ -758,9 +605,8 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( xcclComm_t& comm, at::xpu::XPUStream& stream, int src) { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(output.scalar_type()); - ret_evt = ccl::recv( + ccl::recv( output.data_ptr(), (size_t)output.numel(), xcclDataType, @@ -768,7 +614,7 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, srcRank, OpType::RECV); @@ -837,13 +683,12 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( } } { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); if (rank_ == root) { for (const auto r : c10::irange(size_)) { if (r != root) { // do receive - ret_evt = ccl::recv( + ccl::recv( outputs[r].data_ptr(), (size_t)inputTensor.numel(), xcclDataType, @@ -857,7 +702,7 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( } } else { // do send - ret_evt = ccl::send( + ccl::send( inputTensor.data_ptr(), (size_t)inputTensor.numel(), xcclDataType, @@ -865,13 +710,9 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( comm, ccl::create_stream(stream.queue())); } - return ret_evt; + return; } }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, OpType::GATHER); } @@ -938,14 +779,13 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( } } { - ccl::event ret_evt; if (rank_ == root) { for (const auto r : c10::irange(size_)) { if (r != root) { // do send size_t send_count = inputs[r].numel(); auto send_type = getXcclDataType(inputs[r].scalar_type()); - ret_evt = ccl::send( + ccl::send( inputs[r].data_ptr(), send_count, send_type, @@ -961,7 +801,7 @@ c10::intrusive_ptr 
ProcessGroupXCCL::scatter( // do receive size_t recv_count = outputTensor.numel(); auto recv_type = getXcclDataType(outputTensor.scalar_type()); - ret_evt = ccl::recv( + ccl::recv( outputTensor.data_ptr(), recv_count, recv_type, @@ -970,13 +810,9 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( ccl::create_stream(stream.queue())); } - return ret_evt; + return; } }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, OpType::SCATTER); } @@ -991,10 +827,9 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( ccl::allreduce_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1003,7 +838,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLREDUCE); } @@ -1011,31 +846,38 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "all_reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } - check_xpu_single_tensor(tensor); - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); - return allreduce_impl(tensor, opts); + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue()), + attr); + return; + }, + OpType::ALLREDUCE); } c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { check_xpu_tensors_same_device(tensors); - TORCH_CHECK( - !isFloat8Type(tensors.back().scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); return collectiveCoalesced( tensors, @@ -1045,10 +887,9 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( ccl::allreduce_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1057,7 +898,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::COALESCED); } @@ -1067,9 +908,6 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( const BroadcastOptions& opts) { TORCH_CHECK(tensors.size() == 1, 
MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); - if (tensor.is_complex()) { - tensor = at::view_as_real(tensor); - } check_xpu_single_tensor(tensor); const auto root = opts.rootRank + opts.rootTensor; @@ -1083,8 +921,7 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::broadcast( + ccl::broadcast( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -1092,7 +929,7 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::BROADCAST); } @@ -1101,11 +938,9 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( at::Tensor& outputTensor, at::Tensor& inputTensor, const BroadcastOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _broadcast_oop must have the same number of elements "); - } + TORCH_CHECK_WITH(ValueError, outputTensor.numel() != inputTensor.numel(), + "Tensor input and output of _broadcast_oop must have the same number of elements "); + const auto root = opts.rootRank + opts.rootTensor; return collective( inputTensor, @@ -1116,8 +951,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::broadcast( + ccl::broadcast( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -1125,7 +959,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::BROADCAST); } @@ -1134,16 +968,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } + auto tensor = tensors[0]; check_xpu_single_tensor(tensor); return collective( @@ -1155,10 +980,11 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( xcclComm_t& comm, at::xpu::XPUStream& stream) { const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type()); + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce( + const auto xcclStream = getXcclStream(stream.queue); + + ccl::reduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1167,7 +993,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( root, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::REDUCE); } @@ -1176,11 +1002,9 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( at::Tensor& outputTensor, at::Tensor& inputTensor, const ReduceOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _reduce_oop must have the same number of elements "); - } + TORCH_CHECK_WITH(ValueError, outputTensor.numel() != inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( inputTensor, outputTensor, @@ -1190,10 +1014,9 @@ c10::intrusive_ptr 
ProcessGroupXCCL::_reduce_oop( xcclComm_t& comm, at::xpu::XPUStream& stream) { const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type()); + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce( + ccl::reduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1202,7 +1025,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( root, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::REDUCE); } @@ -1212,39 +1035,47 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( std::vector& inputTensors, const AllgatherOptions& opts) { TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); + auto inputTensor = inputTensors[0]; check_xpu_single_tensor(inputTensor); - // @lint-ignore CLANGTIDY - std::vector& outputTensors_ = outputTensors.back(); + std::vector& outputTensors_ = outputTensors[0]; - bool same_size = check_same_size(outputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. - at::Tensor outputFlattened = newLikeFlat(outputTensors_); - - return collective( + return collective( inputTensor, - outputFlattened, + outputTensors_, [&](at::Tensor& input, - at::Tensor& output, + const std::vector& outputs, ccl::allgather_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, - ret_evt = ccl::allgather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - ccl::create_stream(stream.queue()), - attr); - return ret_evt; + std::vector recvCounts(outputs.size(), 0); + std::transform(outputs.begin(), outputs.end(), recvCounts.begin(), + [](const at::Tensor& t) { + return t.numel(); + }); + + TORCH_CHECK((size_t)input.numel() == recvCounts[rank], "allgather: send and recv count doesn't match"); + std::vector recvBufs(outputs.size(), nullptr); + std::transform(outputs.begin(), outputs.end(), recvBufs.begin(), + [](const at::Tensor& t) { + return t.data_ptr(); + }); + + ccl::allgatherv(input.data_ptr(), + (size_t) input.numel(), + recvBufs, + recvCounts, + cclDatatypes.at(input.scalar_type()), + comm, + ccl::create_stream(stream.queue())); + return; }, [](at::xpu::XPUStream&, c10::intrusive_ptr& work) {}, @@ -1259,19 +1090,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( } }, OpType::ALLGATHER); - } else { - const auto num_reduces = outputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& output = outputTensors_[i]; - auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = BroadcastOptions{ - static_cast(i), static_cast(0), opts.timeout}; - _broadcast_oop(output, input, broadcastOpts); - } - auto work = endCoalescing(OpType::ALLGATHER); - return work; - } + } c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( @@ -1281,16 +1100,10 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( check_xpu_single_tensor(input_tensor); check_xpu_single_tensor(output_tensor); - if (input_tensor.dtype() != output_tensor.dtype()) { - C10_THROW_ERROR( - TypeError, "output tensor must have the same type as input tensor"); - } - - if (input_tensor.numel() * size_ != output_tensor.numel()) { - C10_THROW_ERROR( - ValueError, - "output tensor size must be equal to world_size times input tensor size"); - } + TORCH_CHECK_WITH(TypeError, input_tensor.dtype() != output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH(ValueError, input_tensor.numel() * size_ != output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); return collective( input_tensor, @@ -1303,8 +1116,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::allgather( + ccl::allgather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1312,7 +1124,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::_ALLGATHER_BASE); } @@ -1330,8 +1142,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::allgather( + ccl::allgather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1339,7 +1150,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::COALESCED); } @@ -1354,9 +1165,6 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( check_xpu_single_tensor(outputTensor); // @lint-ignore CLANGTIDY auto inputTensors_ = inputTensors.back(); - TORCH_CHECK( - !isFloat8Type(outputTensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); bool same_size = check_same_size(inputTensors_); if (same_size) { @@ -1372,10 +1180,9 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1383,7 +1190,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, [&](at::xpu::XPUStream& Stream, c10::intrusive_ptr& work) { @@ -1420,22 +1227,10 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( at::Tensor& outputTensor, at::Tensor& inputTensor, const ReduceScatterOptions& opts) { - if (inputTensor.dtype() != outputTensor.dtype()) { - C10_THROW_ERROR( - TypeError, "input tensor must be the same 
type as the output tensor."); - } - - if (inputTensor.numel() != outputTensor.numel() * size_) { - C10_THROW_ERROR( - ValueError, - "input tensor must be the same size as output size times world size"); - } - - // @lint-ignore CLANGTIDY - const auto& tensor = outputTensor; - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); + TORCH_CHECK_WITH(TypeError, input_tensor.dtype() != output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH(ValueError, inputTensor.numel() != outputTensor.numel() * size_, + "input tensor size must be equal to world_size times output tensor size"); return collective( inputTensor, @@ -1447,10 +1242,9 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1458,7 +1252,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::_REDUCE_SCATTER_BASE); } @@ -1467,9 +1261,6 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( std::vector& outputs, std::vector& inputs, const ReduceScatterOptions& opts) { - TORCH_CHECK( - !isFloat8Type(inputs.back().scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); return collectiveCoalesced( inputs, outputs, @@ -1480,10 +1271,9 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1491,7 +1281,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::COALESCED); } @@ -1511,7 +1301,7 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { barDevIdx = static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); } - + // todo: use barrier instead of allreduce TORCH_CHECK_WITH( ValueError, barDevIdx >= 0, @@ -1556,8 +1346,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::alltoall( + ccl::alltoall( input.data_ptr(), output.data_ptr(), (size_t)output.numel() / comm.size(), @@ -1565,7 +1354,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLTOALL_BASE); } else { @@ -1599,9 +1388,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); } auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::event ret_evt; - - ret_evt = ccl::alltoallv( + ccl::alltoallv( input.data_ptr(), sendCounts, output.data_ptr(), @@ -1610,7 +1397,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLTOALL_BASE); } @@ -1663,7 +1450,6 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( } auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); - ccl::event ret_evt; ret_evt = ccl::alltoallv( flatInput.data_ptr(), sendCounts, @@ -1675,7 +1461,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( attr); if (!isOutputFlat) { - ret_evt.wait(); + ret_evt.wait(); // todo: zl_debug why do we need wait here? auto flatOutputSplits = flatOutput.split_with_sizes( c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), 0); @@ -1686,11 +1472,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( } stream.synchronize(); - return ret_evt; - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + return; }, OpType::ALLTOALL); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 88e87ed6713ddf..22a1f62b6e3b25 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -64,9 +64,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); } - void abort() override { - TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); - } + void abort() override; void synchronize() override; @@ -134,7 +132,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& input, at::Tensor& output, Fn fn, - OpType opType); + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + opType); + } template c10::intrusive_ptr collective( @@ -143,7 +151,28 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, PreProcess pre, PostProcess post, - OpType opType); + OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType) { + return collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + opType); + } template c10::intrusive_ptr collective( @@ -159,7 +188,19 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector& input, std::vector& output, Fn fn, - OpType opType); + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + ccl::group_start(); + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + ccl::group_end(); + }, + opType); + } template c10::intrusive_ptr pointToPoint( @@ -168,15 +209,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { int peer, OpType opType); - template - c10::intrusive_ptr pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - PreProcess pre, - PostProcess post); - c10::intrusive_ptr allreduce_impl( at::Tensor& tensor, const AllreduceOptions& opts = AllreduceOptions()); @@ -287,8 +319,6 @@ class TORCH_API ProcessGroupXCCL : public Backend 
{
 protected:
  std::unordered_map<std::string, at::xpu::XPUStream> xcclStreams_;
  std::unordered_map<std::string, at::xpu::XPUEvent> xcclEvents_;
- std::unordered_map<std::string, std::shared_ptr<xcclComm_t>>
-     inInitializationCommMap_;
  std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
  c10::intrusive_ptr<Store> store_;
  std::mutex mutex_;
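
Several of the rewritten hunks above appear not to compile or behave as written; the notes below walk through them in order, with corrected sketches whose reconstructed details (anything not visible in the hunks) should be treated as assumptions. First, in check_xpu_tensors_same_device(), TORCH_CHECK_WITH throws when its condition is false, so asserting tensors.size() == 0 rejects every non-empty list, and the complex check inside the loop refers to tensor although the loop variable is t. A sketch of the intended validation:

int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
  // TORCH_CHECK_WITH throws when the condition is FALSE: assert non-emptiness.
  TORCH_CHECK_WITH(ValueError, !tensors.empty(), "Tensor list must be nonempty");
  const auto& first = tensors.front();
  int64_t total_numel = 0;
  for (const auto& t : tensors) {
    // Use the loop variable 't'; no 'tensor' is declared in this function.
    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
      C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense and non-complex");
    }
    if (t.scalar_type() != first.scalar_type()) {
      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
    }
    // same-device check (TORCH_CHECK_WITH on t.get_device()) as in the original
    total_numel += t.numel();
  }
  return total_numel;
}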
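
getXcclDataType() has the same inversion plus a scoping slip: the new TORCH_CHECK references tensor.scalar_type() although the parameter is already a ScalarType named type, and because TORCH_CHECK throws when its condition is false, the guard as written fails for every non-reduction caller ("currenlty" is also a typo). The check the message implies, with the lookup below it unchanged:

TORCH_CHECK(
    !(isFloat8Type(type) && is_reduction_op),
    "Float8 dtypes are not currently supported for XCCL reductions");

Two related observations on this range: WorkXCCL::abort() erases from devXCCLCommMap_, which appears to belong to the parent ProcessGroupXCCL and is keyed by the device string rather than by communicator, so it is unlikely to compile from the Work class as written; and getXCCLComm() now fills xcclStreamsMap_/xcclEventsMap_ while the header and the remaining call sites (xcclStreams_.at(key), xcclEvents_[key]) still use the old names, so one of the two spellings needs to win.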
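
In the updated collective(), the launch loop iterates for (const auto& input : inputs) but the body indexes inputs[i] and outputs[i], and i is never declared. A sketch of the loop the body seems to intend, pairing each input with its output by index:

for (const auto i : c10::irange(inputs.size())) {
  c10::xpu::XPUCachingAllocator::recordStream(
      inputs[i].storage().data_ptr(), stream);
  fn(inputs[i], outputs[i], attr, *comm, stream);
}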
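
The reworked pointToPoint() now repeats the device guard, the recordStream() call, and the fn() launch in both the coalescing and non-coalescing branches. This is purely a structural suggestion, not a correctness fix: the operation can be issued once and the branch kept only for producing the Work and its future:

at::xpu::OptionalXPUGuard gpuGuard(device);
c10::xpu::XPUCachingAllocator::recordStream(
    tensor.storage().data_ptr(), stream);
fn(tensor, attr, *comm, stream, p2pTargetRank);

if (coalescing_state_) {
  return nullptr;  // the coalesced Work is produced later by endCoalescing()
}
auto work = initWork(device, rank_, opType);
work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
work->outputs_->push_back(tensor);
work->xcclEndEvent_->record(stream);
work->blockingWait_ = blockingWait_;
// build work->future_ under c10::MultiStreamGuard exactly as above, then:
return work;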
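
allreduce() now inlines a lambda that duplicates allreduce_impl(), which the patch keeps (it is still declared in the header and defined just above). If that duplication is not intentional, the public entry point can simply validate and delegate, as the previous code did; a minimal sketch:

c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
  auto tensor = tensors.back();
  check_xpu_single_tensor(tensor);
  return allreduce_impl(tensor, opts);
}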
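
The size guards rewritten with TORCH_CHECK_WITH in _broadcast_oop() and _reduce_oop() keep the old C10_THROW_ERROR predicates, but TORCH_CHECK_WITH asserts its condition, so outputTensor.numel() != inputTensor.numel() now throws exactly when the sizes do match; the same inversion recurs in _allgather_base() and _reduce_scatter_base() below. reduce() also gains const auto xcclStream = getXcclStream(stream.queue);, which calls a helper this patch never introduces, drops the parentheses on queue(), and leaves the result unused, so it reads like a leftover. The guard the error message describes, shown for _broadcast_oop() (and analogously for the others):

TORCH_CHECK_WITH(
    ValueError,
    outputTensor.numel() == inputTensor.numel(),
    "Tensor input and output of _broadcast_oop must have the same number of elements");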
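
The new allgather() body looks like two implementations spliced together: an allgatherv path plus the old flattened-output copy-back, which still references Stream, outputFlattened, rank and cclDatatypes, none of which exist in the new scope, and the braces no longer balance. Assuming the intent is a single ccl::allgatherv() straight into the caller's output tensors (which also makes the flattened staging copy and the removed same-size fallback unnecessary), and assuming a collective() overload that hands the whole output vector to fn, the lambda could look like the sketch below, using this file's rank_ member and getXcclDataType() helper:

[&](at::Tensor& input,
    const std::vector<at::Tensor>& outputs,
    ccl::allgather_attr attr,
    xcclComm_t& comm,
    at::xpu::XPUStream& stream) {
  (void)attr; // unused, matching the allgatherv call in the patch
  std::vector<size_t> recvCounts(outputs.size(), 0);
  std::transform(
      outputs.begin(), outputs.end(), recvCounts.begin(),
      [](const at::Tensor& t) { return t.numel(); });
  TORCH_CHECK(
      (size_t)input.numel() == recvCounts[rank_],
      "allgather: send and recv count doesn't match");

  std::vector<void*> recvBufs(outputs.size(), nullptr);
  std::transform(
      outputs.begin(), outputs.end(), recvBufs.begin(),
      [](const at::Tensor& t) { return t.data_ptr(); });

  // Keep every destination buffer registered on this stream for the async op.
  for (const auto& out : outputs) {
    c10::xpu::XPUCachingAllocator::recordStream(out.storage().data_ptr(), stream);
  }

  ccl::allgatherv(
      input.data_ptr(),
      (size_t)input.numel(),
      recvBufs,
      recvCounts,
      getXcclDataType(input.scalar_type()),
      comm,
      ccl::create_stream(stream.queue()));
  return;
}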
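
_reduce_scatter_base()'s new guards refer to input_tensor and output_tensor, but the parameters of that function are inputTensor and outputTensor, and the predicates need the same flip as above. The intended checks (sketch):

TORCH_CHECK_WITH(
    TypeError,
    inputTensor.dtype() == outputTensor.dtype(),
    "input tensor must be the same type as the output tensor.");
TORCH_CHECK_WITH(
    ValueError,
    inputTensor.numel() == outputTensor.numel() * size_,
    "input tensor size must be equal to world_size times output tensor size");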
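
In alltoall(), the declaration ccl::event ret_evt; is deleted while the assignment ret_evt = ccl::alltoallv(...) and the later ret_evt.wait() are kept, so the function no longer compiles. The wait itself still seems needed (answering the zl_debug question in the hunk): the copy out of flatOutput must not start until the collective has produced it. A sketch of the affected lines, keeping the event where it is used:

ccl::event ret_evt = ccl::alltoallv(
    flatInput.data_ptr(),
    sendCounts,
    flatOutput.data_ptr(),
    recvCounts,
    xcclDataType,
    comm,
    ccl::create_stream(stream.queue()),
    attr);

if (!isOutputFlat) {
  // Block until flatOutput is fully received before splitting and copying it
  // back into the per-rank output tensors (copy loop unchanged from the patch).
  ret_evt.wait();
}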