From 748e5477ea4132717001bc51b7d29a18a01bce6c Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Thu, 10 Oct 2024 14:43:32 +0800 Subject: [PATCH 1/4] refine collective API --- .../distributed/c10d/ProcessGroupXCCL.cpp | 119 +----------------- .../distributed/c10d/ProcessGroupXCCL.hpp | 43 ++++++- 2 files changed, 40 insertions(+), 122 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 3a795c817bec2..15610c655da52 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -19,15 +19,6 @@ namespace c10d { namespace { -// wait nonblocking implement -AutoXcclGroup::AutoXcclGroup() { - ccl::group_start(); -} - -AutoXcclGroup::~AutoXcclGroup() noexcept(false) { - ccl::group_end(); -} - std::map xcclOps = { {ReduceOp::MIN, ccl::reduction::min}, {ReduceOp::MAX, ccl::reduction::max}, @@ -492,10 +483,9 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( for (const auto& input : inputs) { c10::xpu::XPUCachingAllocator::recordStream( input.storage().data_ptr(), stream); + fn(inputs[i], outputs[i], attr, *comm, stream) } - fn(inputs[0], outputs[0], attr, *comm, stream); - post(stream, work); if (!coalescing_state_) { @@ -513,97 +503,6 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType) { - auto inputs = std::vector{input}; - auto outputs = std::vector{output}; - return collective(inputs, outputs, fn, pre, post, opType); -} - -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - OpType opType) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - opType); -} - -template -c10::intrusive_ptr 
ProcessGroupXCCL::collectiveCoalesced( - std::vector& inputs, - std::vector& outputs, - Fn fn, - OpType opType) { - using traits = function_traits; - using attr_t = typename traits::template arg<2>::type; - attr_t attr = ccl::create_operation_attr(); - - auto device = inputs[0].device(); - const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device, opType); - - if (coalescing_state_ & CoalActive) { - coalescing_state_ |= CoalColl; - if (coalescedDevice_.index() < 0) { - coalescedDevice_ = device; - } else { - TORCH_CHECK( - coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); - } - if (coalescedComm_ == nullptr) { - coalescedComm_ = comm; - } else { - TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); - } - } - - auto stream = xcclStreams_.at(key); - syncStream(device, xcclEvents_[key], stream); - - c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - - work->outputs_ = std::make_shared>(outputs); - - at::xpu::OptionalXPUGuard gpuGuard(device); - - { - AutoXcclGroup xccl_group_guard; - for (const auto i : c10::irange(inputs.size())) { - c10::xpu::XPUCachingAllocator::recordStream( - inputs[i].storage().data_ptr(), stream); - fn(inputs[i], outputs[i], attr, *comm, stream); - } - } - - work->xcclEndEvent_->record(stream); - - std::vector streams = {stream.unwrap()}; - c10::MultiStreamGuard streamGuard(streams); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - work->blockingWait_ = blockingWait_; - - return work; -} - template c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( at::Tensor& tensor, @@ -693,22 +592,6 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( } } -template -c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType) { - return pointToPoint( - tensor, - fn, - peer, - opType, - 
[](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&) {}); -} - c10::intrusive_ptr ProcessGroupXCCL::send( std::vector& tensors, int dstRank, diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 88e87ed6713dd..da28bcc0b3c02 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -134,7 +134,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& input, at::Tensor& output, Fn fn, - OpType opType); + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + opType); + } template c10::intrusive_ptr collective( @@ -143,7 +153,11 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, PreProcess pre, PostProcess post, - OpType opType); + OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType); + } template c10::intrusive_ptr collective( @@ -159,14 +173,35 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector& input, std::vector& output, Fn fn, - OpType opType); + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + ccl::group_start(); + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + ccl::group_end(); + }, + opType); + } template c10::intrusive_ptr pointToPoint( at::Tensor& tensor, Fn fn, int peer, - OpType opType); + OpType opType) { + return pointToPoint( + tensor, + fn, + peer, + opType, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&) {}); + } template c10::intrusive_ptr pointToPoint( From 69c22f9cf0a8707cca9ef74fe3a49ce92cb3eb03 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Thu, 10 Oct 2024 22:59:11 +0800 Subject: [PATCH 2/4] remove unneeded --- 
.../distributed/c10d/ProcessGroupXCCL.cpp | 191 +++++++++--------- .../distributed/c10d/ProcessGroupXCCL.hpp | 23 ++- 2 files changed, 108 insertions(+), 106 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 15610c655da52..c51e3a8e7c426 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -43,6 +43,36 @@ std::map xcclDatatypes = { {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, }; +XCCL_KVS kvs; +std::mutex kvs_mutex; + +XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + + return kvs; +} + bool computeLengthsAndCheckAndGetFlat( const std::vector& tensors, std::vector& lengths, @@ -221,6 +251,11 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() { return false; } +void ProcessGroupXCCL::WorkXCCL::abort() { + std::lock_guard lock(mutex_); + devXCCLCommMap_.erase(xcclComm_); +} + void ProcessGroupXCCL::WorkXCCL::synchronize() { synchronizeInternal(kNoTimeout); } @@ -239,9 +274,9 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( "Work ran for ", timeElapsed.count(), " milliseconds before timing out."); - TORCH_CHECK(false, exceptionMsg) + abort(); + TORCH_CHECK(false, exceptionMsg); } - 
std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } @@ -326,10 +361,11 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( at::xpu::OptionalXPUGuard gpuGuard(device); - for (const auto i : c10::irange(xcclActiveGroupCounter_)) { - (void)i; - ccl::group_end(); - } +// todo: I think we don't need it now, with Coalesced manager +// for (const auto i : c10::irange(xcclActiveGroupCounter_)) { +// (void)i; +// ccl::group_end(); +// } int numRanks, rank; if (!singleP2POp) { @@ -356,32 +392,18 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( { std::lock_guard lock(mutex_); - inInitializationCommMap_.emplace(deviceKey, XCCLComm); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); } - for (const auto i : c10::irange(xcclActiveGroupCounter_)) { - (void)i; - ccl::group_start(); - } +// for (const auto i : c10::irange(xcclActiveGroupCounter_)) { +// (void)i; +// ccl::group_start(); +// } xcclStreams_.emplace(deviceKey, std::move(stream)); xcclEvents_.emplace(deviceKey, at::xpu::XPUEvent()); - auto it = inInitializationCommMap_.find(deviceKey); - if (it != inInitializationCommMap_.end()) { - devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); - inInitializationCommMap_.erase(deviceKey); - - xcclCommDevIdxMapMutex.lock(); - xcclCommDevIdxMap.emplace(XCCLComm, device.index()); - xcclCommDevIdxMapMutex.unlock(); - } - - it = devXCCLCommMap_.find(deviceKey); - TORCH_INTERNAL_ASSERT( - it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); - - return it->second; + return XCCLComm; } void ProcessGroupXCCL::groupStart() { @@ -483,7 +505,7 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( for (const auto& input : inputs) { c10::xpu::XPUCachingAllocator::recordStream( input.storage().data_ptr(), stream); - fn(inputs[i], outputs[i], attr, *comm, stream) + fn(inputs[i], outputs[i], attr, *comm, stream); } post(stream, work); @@ -608,9 +630,8 @@ c10::intrusive_ptr ProcessGroupXCCL::send( xcclComm_t& comm, at::xpu::XPUStream& stream, 
int dst) { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(input.scalar_type()); - ret_evt = ccl::send( + ccl::send( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -618,7 +639,7 @@ c10::intrusive_ptr ProcessGroupXCCL::send( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, dstRank, OpType::SEND); @@ -641,9 +662,8 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( xcclComm_t& comm, at::xpu::XPUStream& stream, int src) { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(output.scalar_type()); - ret_evt = ccl::recv( + ccl::recv( output.data_ptr(), (size_t)output.numel(), xcclDataType, @@ -651,7 +671,7 @@ c10::intrusive_ptr ProcessGroupXCCL::recv( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, srcRank, OpType::RECV); @@ -720,13 +740,12 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( } } { - ccl::event ret_evt; auto xcclDataType = getXcclDataType(inputTensor.scalar_type()); if (rank_ == root) { for (const auto r : c10::irange(size_)) { if (r != root) { // do receive - ret_evt = ccl::recv( + ccl::recv( outputs[r].data_ptr(), (size_t)inputTensor.numel(), xcclDataType, @@ -740,7 +759,7 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( } } else { // do send - ret_evt = ccl::send( + ccl::send( inputTensor.data_ptr(), (size_t)inputTensor.numel(), xcclDataType, @@ -748,13 +767,9 @@ c10::intrusive_ptr ProcessGroupXCCL::gather( comm, ccl::create_stream(stream.queue())); } - return ret_evt; + return; } }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, OpType::GATHER); } @@ -821,14 +836,13 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( } } { - ccl::event ret_evt; if (rank_ == root) { for (const auto r : c10::irange(size_)) { if (r != root) { // do send size_t send_count = inputs[r].numel(); auto send_type = getXcclDataType(inputs[r].scalar_type()); - ret_evt = ccl::send( + ccl::send( inputs[r].data_ptr(), send_count, send_type, @@ 
-844,7 +858,7 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( // do receive size_t recv_count = outputTensor.numel(); auto recv_type = getXcclDataType(outputTensor.scalar_type()); - ret_evt = ccl::recv( + ccl::recv( outputTensor.data_ptr(), recv_count, recv_type, @@ -853,13 +867,9 @@ c10::intrusive_ptr ProcessGroupXCCL::scatter( ccl::create_stream(stream.queue())); } - return ret_evt; + return; } }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, OpType::SCATTER); } @@ -876,8 +886,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -886,7 +895,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLREDUCE); } @@ -930,8 +939,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -940,7 +948,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::COALESCED); } @@ -966,8 +974,7 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::broadcast( + ccl::broadcast( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -975,7 +982,7 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( comm, ccl::create_stream(stream.queue()), 
attr); - return ret_evt; + return; }, OpType::BROADCAST); } @@ -999,8 +1006,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::broadcast( + ccl::broadcast( input.data_ptr(), (size_t)input.numel(), xcclDataType, @@ -1008,7 +1014,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::BROADCAST); } @@ -1040,8 +1046,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( const int root = opts.rootRank + opts.rootTensor; const auto xcclDataType = getXcclDataType(input.scalar_type()); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce( + ccl::reduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1050,7 +1055,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( root, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::REDUCE); } @@ -1075,8 +1080,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( const int root = opts.rootRank + opts.rootTensor; const auto xcclDataType = getXcclDataType(input.scalar_type()); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce( + ccl::reduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1085,7 +1089,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( root, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::REDUCE); } @@ -1117,9 +1121,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - - ret_evt = ccl::allgather( + ccl::allgather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1127,7 +1129,7 @@ c10::intrusive_ptr 
ProcessGroupXCCL::allgather( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, [](at::xpu::XPUStream&, c10::intrusive_ptr& work) {}, @@ -1186,8 +1188,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::allgather( + ccl::allgather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1195,7 +1196,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::_ALLGATHER_BASE); } @@ -1213,8 +1214,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::allgather( + ccl::allgather( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), @@ -1222,7 +1222,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::COALESCED); } @@ -1257,8 +1257,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1266,7 +1265,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, [&](at::xpu::XPUStream& Stream, c10::intrusive_ptr& work) { @@ -1332,8 +1331,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); auto xcclReduceOp 
= getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1341,7 +1339,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::_REDUCE_SCATTER_BASE); } @@ -1365,8 +1363,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(input.scalar_type()); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::reduce_scatter( + ccl::reduce_scatter( input.data_ptr(), output.data_ptr(), (size_t)output.numel(), @@ -1374,7 +1371,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( xcclReduceOp, comm, ccl::create_stream(stream.queue())); - return ret_evt; + return; }, OpType::COALESCED); } @@ -1394,7 +1391,7 @@ c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { barDevIdx = static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); } - + // todo: use barrier instead of allreduce TORCH_CHECK_WITH( ValueError, barDevIdx >= 0, @@ -1439,8 +1436,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::event ret_evt; - ret_evt = ccl::alltoall( + ccl::alltoall( input.data_ptr(), output.data_ptr(), (size_t)output.numel() / comm.size(), @@ -1448,7 +1444,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLTOALL_BASE); } else { @@ -1482,9 +1478,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( (outputSplitsEqual ? 
outLen : outputSplitSizes[i] * outLen); } auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::event ret_evt; - - ret_evt = ccl::alltoallv( + ccl::alltoallv( input.data_ptr(), sendCounts, output.data_ptr(), @@ -1493,7 +1487,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( comm, ccl::create_stream(stream.queue()), attr); - return ret_evt; + return; }, OpType::ALLTOALL_BASE); } @@ -1546,7 +1540,6 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( } auto xcclDataType = getXcclDataType(flatOutput.scalar_type()); - ccl::event ret_evt; ret_evt = ccl::alltoallv( flatInput.data_ptr(), sendCounts, @@ -1558,7 +1551,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( attr); if (!isOutputFlat) { - ret_evt.wait(); + ret_evt.wait(); // todo: zl_debug why do we need wait here? auto flatOutputSplits = flatOutput.split_with_sizes( c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()), 0); @@ -1569,11 +1562,7 @@ c10::intrusive_ptr ProcessGroupXCCL::alltoall( } stream.synchronize(); - return ret_evt; - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + return; }, OpType::ALLTOALL); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index da28bcc0b3c02..9099181ffbfa9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -64,9 +64,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); } - void abort() override { - TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); - } + void abort() override; void synchronize() override; @@ -159,6 +157,23 @@ class TORCH_API ProcessGroupXCCL : public Backend { return collective(inputs, outputs, fn, pre, post, opType); } + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType) { + return 
collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + [](at::xpu::XPUStream&, c10::intrusive_ptr&) { + }, + opType); + } + template c10::intrusive_ptr collective( std::vector& inputs, @@ -322,8 +337,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: std::unordered_map xcclStreams_; std::unordered_map xcclEvents_; - std::unordered_map> - inInitializationCommMap_; std::unordered_map> devXCCLCommMap_; c10::intrusive_ptr store_; std::mutex mutex_; From a3b2b0a37aa18f36e88f9d4291e2dcc0ff8e5d54 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Fri, 11 Oct 2024 14:21:21 +0800 Subject: [PATCH 3/4] refine code --- .../distributed/c10d/ProcessGroupXCCL.cpp | 45 ++++++++----------- .../distributed/c10d/ProcessGroupXCCL.hpp | 20 +-------- 2 files changed, 20 insertions(+), 45 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index c51e3a8e7c426..7ed4502743299 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -416,7 +416,6 @@ void ProcessGroupXCCL::groupEnd() { --xcclActiveGroupCounter_; } -// TODO: wait p2p enable static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; void ProcessGroupXCCL::startCoalescing() { coalescedDevice_.set_index(-1); @@ -525,14 +524,12 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template +template c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( at::Tensor& tensor, Fn fn, int peer, - OpType opType, - PreProcess pre, - PostProcess post) { + OpType opType) { using traits = function_traits; using attr_t = typename traits::template arg<1>::type; attr_t attr = ccl::create_operation_attr(); @@ -576,40 +573,36 @@ c10::intrusive_ptr ProcessGroupXCCL::pointToPoint( auto stream = xcclStreams_.at(key); syncStream(device, xcclEvents_[key], stream); - c10::intrusive_ptr work; if (!coalescing_state_) { + c10::intrusive_ptr 
work; work = initWork(device, rank_, opType); work->outputs_ = std::make_shared>(); work->outputs_->push_back(tensor); - } - at::xpu::OptionalXPUGuard gpuGuard(device); + at::xpu::OptionalXPUGuard gpuGuard(device); - if (!coalescing_state_) { - pre(stream, work); - } - - c10::xpu::XPUCachingAllocator::recordStream( + c10::xpu::XPUCachingAllocator::recordStream( tensor.storage().data_ptr(), stream); - fn(tensor, attr, *comm, stream, p2pTargetRank); - - if (!coalescing_state_) { - post(stream); + fn(tensor, attr, *comm, stream, p2pTargetRank); work->xcclEndEvent_->record(stream); work->blockingWait_ = blockingWait_; - - { - std::vector streams = {stream.unwrap()}; - c10::MultiStreamGuard streamGuard(streams); - std::vector devices{device}; - work->future_ = c10::make_intrusive( - c10::ListType::create(c10::TensorType::get()), devices); - work->future_->markCompleted(at::IValue(*work->outputs_)); - } + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); return work; } else { + at::xpu::OptionalXPUGuard gpuGuard(device); + + c10::xpu::XPUCachingAllocator::recordStream( + tensor.storage().data_ptr(), stream); + + fn(tensor, attr, *comm, stream, p2pTargetRank); + return nullptr; } } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 9099181ffbfa9..22a1f62b6e3b2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -207,25 +207,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& tensor, Fn fn, int peer, - OpType opType) { - return pointToPoint( - tensor, - fn, - peer, - opType, - [](at::xpu::XPUStream&, c10::intrusive_ptr&) { - }, - [](at::xpu::XPUStream&) {}); - } - - template - c10::intrusive_ptr 
pointToPoint( - at::Tensor& tensor, - Fn fn, - int peer, - OpType opType, - PreProcess pre, - PostProcess post); + OpType opType); c10::intrusive_ptr allreduce_impl( at::Tensor& tensor, From fd6a11dbe83443d9abf7461ed6289494e007d5c5 Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Fri, 11 Oct 2024 17:26:19 +0800 Subject: [PATCH 4/4] debug --- .../distributed/c10d/ProcessGroupXCCL.cpp | 287 +++++++----------- 1 file changed, 102 insertions(+), 185 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 7ed4502743299..8867904d92e24 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -43,36 +43,6 @@ std::map xcclDatatypes = { {at::kFloat8_e5m2fnuz, ccl::datatype::uint8}, }; -XCCL_KVS kvs; -std::mutex kvs_mutex; - -XCCL_KVS get_kvs(int rank, c10d::Store& store) { - std::lock_guard lock(kvs_mutex); - if (kvs) - return kvs; - std::string storeKey = "xccl_kvs"; - - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - - return kvs; -} - bool computeLengthsAndCheckAndGetFlat( const std::vector& tensors, std::vector& lengths, @@ -120,44 +90,32 @@ bool check_same_size(const std::vector& input_tensors) { return true; } -void check_xpu_single_tensor( - const at::Tensor& tensor, - const bool p2p = false // whether operation is a P2P operation -) { - if 
(!tensor.is_xpu() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); +void check_xpu_single_tensor(const at::Tensor& tensor) { + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense and non-complex"); } - // Skip the following requirements for P2P operations if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - if (p2p) { - TORCH_WARN_ONCE( - "Detected non-contiguous tensor in P2P operations. It is user " - "responsibility to guarantee that source and destination tensors have " - "the same contiguity format."); - } else { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); - } + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); } } int64_t check_xpu_tensors_same_device(const std::vector& tensors) { - if (tensors.size() == 0) { - C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); - } + TORCH_CHECK_WITH(ValueError, tensors.size() == 0, "Tensor list must be nonempty"); const auto& first = tensors.front(); int64_t total_numel = 0; for (const auto& t : tensors) { - if (!t.is_xpu() || t.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); + if (!t.is_xpu() || t.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense and non-complex"); } if (t.scalar_type() != first.scalar_type()) { C10_THROW_ERROR(TypeError, "Tensors must have identical type"); } - if (!t.is_non_overlapping_and_dense()) { - C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); - } + // todo: zl_debug do we need it? 
+// if (!t.is_non_overlapping_and_dense()) { +// C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense"); +// } TORCH_CHECK_WITH( ValueError, t.get_device() == tensors[0].get_device(), @@ -168,7 +126,11 @@ int64_t check_xpu_tensors_same_device(const std::vector& tensors) { return total_numel; } -ccl::datatype getXcclDataType(at::ScalarType type) { +ccl::datatype getXcclDataType(at::ScalarType type, bool is_reduction_op = false) { + TORCH_CHECK( + !isFloat8Type(tensor.scalar_type()) && is_reduction_op, + "Float8 dtypes are not currenlty supported for XCCL reductions"); + auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( TypeError, @@ -204,18 +166,6 @@ void syncStream( xcclEvent.block(xcclStream); } -bool complexViewAsRealAllowed(const ReduceOp reduceOp) { - switch (reduceOp) { - case ReduceOp::SUM: - return true; - case ReduceOp::UNUSED: - return true; - default: - return false; - } - return false; -} - } // namespace static std::mutex xcclCommDevIdxMapMutex; @@ -400,8 +350,8 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( // ccl::group_start(); // } - xcclStreams_.emplace(deviceKey, std::move(stream)); - xcclEvents_.emplace(deviceKey, at::xpu::XPUEvent()); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); return XCCLComm; } @@ -877,7 +827,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( ccl::allreduce_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::allreduce( input.data_ptr(), @@ -896,31 +846,38 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - auto tensor = tensors.back(); - if (tensor.is_complex()) { - 
TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "all_reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } - check_xpu_single_tensor(tensor); - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); - return allreduce_impl(tensor, opts); + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue()), + attr); + return; + }, + OpType::ALLREDUCE); } c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts) { check_xpu_tensors_same_device(tensors); - TORCH_CHECK( - !isFloat8Type(tensors.back().scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); return collectiveCoalesced( tensors, @@ -930,7 +887,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( ccl::allreduce_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::allreduce( input.data_ptr(), @@ -951,9 +908,6 @@ c10::intrusive_ptr ProcessGroupXCCL::broadcast( const BroadcastOptions& opts) { TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); - if (tensor.is_complex()) { - tensor = at::view_as_real(tensor); - } check_xpu_single_tensor(tensor); const 
auto root = opts.rootRank + opts.rootTensor; @@ -984,11 +938,9 @@ c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( at::Tensor& outputTensor, at::Tensor& inputTensor, const BroadcastOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _broadcast_oop must have the same number of elements "); - } + TORCH_CHECK_WITH(ValueError, outputTensor.numel() == inputTensor.numel(), + "Tensor input and output of _broadcast_oop must have the same number of elements "); + const auto root = opts.rootRank + opts.rootTensor; return collective( inputTensor, @@ -1016,16 +968,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( std::vector& tensors, const ReduceOptions& opts) { TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto tensor = tensors.back(); - if (tensor.is_complex()) { - TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), - "reduce does not support", - opts.reduceOp, - "on complex tensors"); - tensor = at::view_as_real(tensor); - } + auto tensor = tensors[0]; check_xpu_single_tensor(tensor); return collective( @@ -1037,8 +980,10 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce( xcclComm_t& comm, at::xpu::XPUStream& stream) { const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type()); + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + const auto xcclStream = getXcclStream(stream.queue()); + ccl::reduce( input.data_ptr(), output.data_ptr(), @@ -1057,11 +1002,9 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( at::Tensor& outputTensor, at::Tensor& inputTensor, const ReduceOptions& opts) { - if (outputTensor.numel() != inputTensor.numel()) { - C10_THROW_ERROR( - ValueError, - "Tensor input and output of _reduce_oop must have the same number of elements "); - } + TORCH_CHECK_WITH(ValueError, outputTensor.numel() == 
inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( inputTensor, outputTensor, @@ -1071,7 +1014,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( xcclComm_t& comm, at::xpu::XPUStream& stream) { const int root = opts.rootRank + opts.rootTensor; - const auto xcclDataType = getXcclDataType(input.scalar_type()); + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::reduce( input.data_ptr(), @@ -1092,37 +1035,47 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( std::vector& inputTensors, const AllgatherOptions& opts) { TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); - // @lint-ignore CLANGTIDY - auto inputTensor = inputTensors.back(); + auto inputTensor = inputTensors[0]; check_xpu_single_tensor(inputTensor); - // @lint-ignore CLANGTIDY - std::vector& outputTensors_ = outputTensors.back(); - - bool same_size = check_same_size(outputTensors_); - if (same_size) { - // Flatten a vector of tensors into a single, stacked tensor. 
- at::Tensor outputFlattened = newLikeFlat(outputTensors_); + std::vector& outputTensors_ = outputTensors[0]; - return collective( + return collective( inputTensor, - outputFlattened, + outputTensors_, [&](at::Tensor& input, - at::Tensor& output, + const std::vector& outputs, ccl::allgather_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); - ccl::allgather( - input.data_ptr(), - output.data_ptr(), - (size_t)input.numel(), - xcclDataType, - comm, - ccl::create_stream(stream.queue()), - attr); - return; + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + + std::vector recvCounts(outputs.size(), 0); + std::transform(outputs.begin(), outputs.end(), recvCounts.begin(), + [](const at::Tensor& t) { + return t.numel(); + }); + + TORCH_CHECK((size_t)input.numel() == recvCounts[rank], "allgather: send and recv count doesn't match"); + std::vector recvBufs(outputs.size(), nullptr); + std::transform(outputs.begin(), outputs.end(), recvBufs.begin(), + [](const at::Tensor& t) { + return t.data_ptr(); + }); + + ccl::allgatherv(input.data_ptr(), + (size_t) input.numel(), + recvBufs, + recvCounts, + cclDatatypes.at(input.scalar_type()), + comm, + ccl::create_stream(stream.queue())); + return; }, [](at::xpu::XPUStream&, c10::intrusive_ptr& work) {}, @@ -1137,19 +1090,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allgather( } }, OpType::ALLGATHER); - } else { - const auto num_reduces = outputTensors_.size(); - startCoalescing(); - for (const int i : c10::irange(num_reduces)) { - auto& output = outputTensors_[i]; - auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = BroadcastOptions{ - static_cast(i), static_cast(0), opts.timeout}; - _broadcast_oop(output, input, broadcastOpts); - } - auto work = endCoalescing(OpType::ALLGATHER); - return work; - } + } c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( @@ -1159,16 +1100,10 @@ c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( check_xpu_single_tensor(input_tensor); check_xpu_single_tensor(output_tensor); - if (input_tensor.dtype() != output_tensor.dtype()) { - C10_THROW_ERROR( - TypeError, "output tensor must have the same type as input tensor"); - } - - if (input_tensor.numel() * size_ != output_tensor.numel()) { - C10_THROW_ERROR( - ValueError, - "output tensor size must be equal to world_size times input tensor size"); - } + TORCH_CHECK_WITH(TypeError, input_tensor.dtype() == output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH(ValueError, input_tensor.numel() * size_ == output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); return collective( input_tensor, @@ -1230,9 +1165,6 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( check_xpu_single_tensor(outputTensor); // @lint-ignore CLANGTIDY auto inputTensors_ = inputTensors.back(); - TORCH_CHECK( - !isFloat8Type(outputTensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); bool same_size = check_same_size(inputTensors_); if (same_size) { @@ -1248,7 +1180,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::reduce_scatter( input.data_ptr(), @@ -1295,22 +1227,10 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( at::Tensor& 
outputTensor, at::Tensor& inputTensor, const ReduceScatterOptions& opts) { - if (inputTensor.dtype() != outputTensor.dtype()) { - C10_THROW_ERROR( - TypeError, "input tensor must be the same type as the output tensor."); - } - - if (inputTensor.numel() != outputTensor.numel() * size_) { - C10_THROW_ERROR( - ValueError, - "input tensor must be the same size as output size times world size"); - } - - // @lint-ignore CLANGTIDY - const auto& tensor = outputTensor; - TORCH_CHECK( - !isFloat8Type(tensor.scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); + TORCH_CHECK_WITH(TypeError, inputTensor.dtype() == outputTensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH(ValueError, inputTensor.numel() == outputTensor.numel() * size_, + "input tensor size must be equal to world_size times output tensor size"); return collective( inputTensor, @@ -1322,7 +1242,7 @@ c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::reduce_scatter( input.data_ptr(), @@ -1341,9 +1261,6 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( std::vector& outputs, std::vector& inputs, const ReduceScatterOptions& opts) { - TORCH_CHECK( - !isFloat8Type(inputs.back().scalar_type()), - "Float8 dtypes are not currenlty supported for XCCL reductions"); return collectiveCoalesced( inputs, outputs, @@ -1354,7 +1271,7 @@ c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( at::xpu::XPUStream& stream) { c10::xpu::XPUCachingAllocator::recordStream( output.storage().data_ptr(), stream); - auto xcclDataType = 
getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::reduce_scatter( input.data_ptr(),