diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index 7ed4502743299e..8867904d92e24c 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -43,36 +43,6 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
 };
 
-XCCL_KVS kvs;
-std::mutex kvs_mutex;
-
-XCCL_KVS get_kvs(int rank, c10d::Store& store) {
-  std::lock_guard<std::mutex> lock(kvs_mutex);
-  if (kvs)
-    return kvs;
-  std::string storeKey = "xccl_kvs";
-
-  // Rank 0 broadcast the bootstrap network information to other ranks
-  if (rank == 0) {
-    kvs = ccl::create_main_kvs();
-    ccl::kvs::address_type main_addr = kvs->get_address();
-    auto ccl_kvs_addr =
-        std::vector<uint8_t>(main_addr.begin(), main_addr.end());
-    store.set(storeKey, ccl_kvs_addr);
-  } else {
-    auto ccl_kvs_addr = store.get(storeKey);
-    if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
-      throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
-    }
-    ccl::kvs::address_type main_addr;
-    std::copy_n(
-        ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
-    kvs = ccl::create_kvs(main_addr);
-  }
-
-  return kvs;
-}
-
 bool computeLengthsAndCheckAndGetFlat(
     const std::vector<at::Tensor>& tensors,
     std::vector<size_t>& lengths,
@@ -120,44 +90,32 @@ bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
   return true;
 }
 
-void check_xpu_single_tensor(
-    const at::Tensor& tensor,
-    const bool p2p = false // whether operation is a P2P operation
-) {
-  if (!tensor.is_xpu() || tensor.is_sparse()) {
-    C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
+void check_xpu_single_tensor(const at::Tensor& tensor) {
+  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
+    C10_THROW_ERROR(ValueError, "Tensors must be XPU, dense, and non-complex");
   }
-  // Skip the following requirements for P2P operations
   if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
-    if (p2p) {
-      TORCH_WARN_ONCE(
-          "Detected non-contiguous tensor in P2P operations. It is user "
-          "responsibility to guarantee that source and destination tensors have "
-          "the same contiguity format.");
-    } else {
-      C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
-    }
+    C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
   }
 }
 
 int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
-  if (tensors.size() == 0) {
-    C10_THROW_ERROR(ValueError, "Tensor list must be nonempty");
-  }
+  TORCH_CHECK_WITH(ValueError, !tensors.empty(), "Tensor list must be nonempty");
 
   const auto& first = tensors.front();
 
   int64_t total_numel = 0;
   for (const auto& t : tensors) {
-    if (!t.is_xpu() || t.is_sparse()) {
-      C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
+    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
+      C10_THROW_ERROR(ValueError, "Tensors must be XPU, dense, and non-complex");
     }
     if (t.scalar_type() != first.scalar_type()) {
       C10_THROW_ERROR(TypeError, "Tensors must have identical type");
    }
-    if (!t.is_non_overlapping_and_dense()) {
-      C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
-    }
+    // TODO: confirm whether the non-overlapping-and-dense requirement is
+    // needed for XCCL; the check is kept below, commented out, until then.
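Note on the validation macro used in the hunk above: TORCH_CHECK_WITH(ErrorType, cond, msg) throws ErrorType when cond evaluates to false, so its second argument is the invariant that must hold, i.e. the logical opposite of the condition in an if (...) C10_THROW_ERROR(...) block. A minimal, self-contained sketch of the pattern (the helper name below is illustrative only, not part of the file):

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>

    // Illustrative helper mirroring the validation style used in this file.
    // TORCH_CHECK_WITH's second argument is the condition that must be true.
    void checkNonemptyAndSameNumel(const at::Tensor& in, const at::Tensor& out) {
      // Throws c10::ValueError if `in` has no elements.
      TORCH_CHECK_WITH(ValueError, in.numel() > 0, "Tensor must be nonempty");
      // Throws c10::ValueError unless both tensors have the same numel.
      TORCH_CHECK_WITH(
          ValueError,
          in.numel() == out.numel(),
          "Input and output must have the same number of elements");
    }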
+    // if (!t.is_non_overlapping_and_dense()) {
+    //   C10_THROW_ERROR(ValueError, "Tensors must be non-overlapping and dense");
+    // }
     TORCH_CHECK_WITH(
         ValueError,
         t.get_device() == tensors[0].get_device(),
@@ -168,7 +126,11 @@ int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
   return total_numel;
 }
 
-ccl::datatype getXcclDataType(at::ScalarType type) {
+ccl::datatype getXcclDataType(at::ScalarType type, bool is_reduction_op = false) {
+  TORCH_CHECK(
+      !(is_reduction_op && isFloat8Type(type)),
+      "Float8 dtypes are not currently supported for XCCL reductions");
+
   auto it = xcclDatatypes.find(type);
   TORCH_CHECK_WITH(
       TypeError,
@@ -204,18 +166,6 @@ void syncStream(
   xcclEvent.block(xcclStream);
 }
 
-bool complexViewAsRealAllowed(const ReduceOp reduceOp) {
-  switch (reduceOp) {
-    case ReduceOp::SUM:
-      return true;
-    case ReduceOp::UNUSED:
-      return true;
-    default:
-      return false;
-  }
-  return false;
-}
-
 } // namespace
 
 static std::mutex xcclCommDevIdxMapMutex;
@@ -400,8 +350,8 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   //   ccl::group_start();
   // }
 
-  xcclStreams_.emplace(deviceKey, std::move(stream));
-  xcclEvents_.emplace(deviceKey, at::xpu::XPUEvent());
+  xcclStreamsMap_.emplace(deviceKey, std::move(stream));
+  xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
 
   return XCCLComm;
 }
@@ -877,7 +827,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
         ccl::allreduce_attr attr,
         xcclComm_t& comm,
         at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
         ccl::allreduce(
             input.data_ptr(),
@@ -896,31 +846,38 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    TORCH_CHECK(
-        complexViewAsRealAllowed(opts.reduceOp),
-        "all_reduce does not support",
-        opts.reduceOp,
-        "on complex tensors");
-    tensor = at::view_as_real(tensor);
-  }
-  check_xpu_single_tensor(tensor);
-  TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
+  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  auto tensor = tensors.back();
+  check_xpu_single_tensor(tensor);
 
-  return allreduce_impl(tensor, opts);
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          ccl::allreduce_attr attr,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl::create_stream(stream.queue()),
+            attr);
+        return;
+      },
+      OpType::ALLREDUCE);
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
     std::vector<at::Tensor>& tensors,
     const AllreduceCoalescedOptions& opts) {
   check_xpu_tensors_same_device(tensors);
-  TORCH_CHECK(
-      !isFloat8Type(tensors.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
   return collectiveCoalesced(
       tensors,
@@ -930,7 +887,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
           ccl::allreduce_attr attr,
          xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
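For the reduction gate added to getXcclDataType in the hunk above: the check is meant to fire only when a float8 dtype is looked up for a reducing collective, while non-reducing collectives may still move float8 data as raw bytes. A minimal stand-alone sketch of that predicate, assuming only the public c10 helpers (the free function below is illustrative, not part of the file):

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>

    // Reject float8 only for reductions; pass-through collectives are allowed.
    void rejectFloat8Reductions(at::ScalarType type, bool is_reduction_op) {
      TORCH_CHECK(
          !(is_reduction_op && isFloat8Type(type)),
          "Float8 dtypes are not currently supported for XCCL reductions");
    }

    // Usage: rejectFloat8Reductions(at::kFloat8_e5m2fnuz, /*is_reduction_op=*/true)
    // throws, while the same dtype with is_reduction_op=false is accepted.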
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        ccl::allreduce(
            input.data_ptr(),
@@ -951,9 +908,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
     const BroadcastOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    tensor = at::view_as_real(tensor);
-  }
   check_xpu_single_tensor(tensor);
 
   const auto root = opts.rootRank + opts.rootTensor;
@@ -984,11 +938,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
     at::Tensor& outputTensor,
     at::Tensor& inputTensor,
     const BroadcastOptions& opts) {
-  if (outputTensor.numel() != inputTensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "Tensor input and output of _broadcast_oop must have the same number of elements ");
-  }
+  TORCH_CHECK_WITH(
+      ValueError,
+      outputTensor.numel() == inputTensor.numel(),
+      "Tensor input and output of _broadcast_oop must have the same number of elements");
+
   const auto root = opts.rootRank + opts.rootTensor;
   return collective(
       inputTensor,
@@ -1016,16 +968,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
     std::vector<at::Tensor>& tensors,
     const ReduceOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto tensor = tensors.back();
-  if (tensor.is_complex()) {
-    TORCH_CHECK(
-        complexViewAsRealAllowed(opts.reduceOp),
-        "reduce does not support",
-        opts.reduceOp,
-        "on complex tensors");
-    tensor = at::view_as_real(tensor);
-  }
+  auto tensor = tensors[0];
   check_xpu_single_tensor(tensor);
 
   return collective(
@@ -1037,8 +980,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
           xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
        const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type());
+        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        const auto xcclStream = ccl::create_stream(stream.queue());
+
        ccl::reduce(
            input.data_ptr(),
            output.data_ptr(),
@@ -1057,11 +1002,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
     at::Tensor& outputTensor,
     at::Tensor& inputTensor,
     const ReduceOptions& opts) {
-  if (outputTensor.numel() != inputTensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "Tensor input and output of _reduce_oop must have the same number of elements ");
-  }
+  TORCH_CHECK_WITH(
+      ValueError,
+      outputTensor.numel() == inputTensor.numel(),
+      "Tensor input and output of _reduce_oop must have the same number of elements");
+
   return collective(
       inputTensor,
       outputTensor,
@@ -1071,7 +1014,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
           xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
        const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type());
+        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        ccl::reduce(
            input.data_ptr(),
@@ -1092,37 +1035,47 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
     std::vector<at::Tensor>& inputTensors,
     const AllgatherOptions& opts) {
   TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto inputTensor = inputTensors.back();
+  auto inputTensor = inputTensors[0];
   check_xpu_single_tensor(inputTensor);
-  // @lint-ignore CLANGTIDY
-  std::vector<at::Tensor>& outputTensors_ = outputTensors.back();
-
-  bool same_size = check_same_size(outputTensors_);
-  if (same_size) {
-    // Flatten a vector of tensors into a single, stacked tensor.
-    at::Tensor outputFlattened = newLikeFlat(outputTensors_);
+  std::vector<at::Tensor>& outputTensors_ = outputTensors[0];
 
-    return collective(
+  return collective(
       inputTensor,
-      outputFlattened,
+      outputTensors_,
       [&](at::Tensor& input,
-          at::Tensor& output,
+          const std::vector<at::Tensor>& outputs,
           ccl::allgather_attr attr,
          xcclComm_t& comm,
          at::xpu::XPUStream& stream) {
-        c10::xpu::XPUCachingAllocator::recordStream(
-            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::allgather(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            comm,
-            ccl::create_stream(stream.queue()),
-            attr);
-        return;
+        for (const auto j : c10::irange(outputs.size())) {
+          c10::xpu::XPUCachingAllocator::recordStream(
+              outputs[j].storage().data_ptr(), stream);
+        }
+
+        std::vector<size_t> recvCounts(outputs.size(), 0);
+        std::transform(
+            outputs.begin(),
+            outputs.end(),
+            recvCounts.begin(),
+            [](const at::Tensor& t) { return t.numel(); });
+
+        TORCH_CHECK(
+            (size_t)input.numel() == recvCounts[rank_],
+            "allgather: send and recv count doesn't match");
+
+        std::vector<void*> recvBufs(outputs.size(), nullptr);
+        std::transform(
+            outputs.begin(),
+            outputs.end(),
+            recvBufs.begin(),
+            [](const at::Tensor& t) { return t.data_ptr(); });
+
+        ccl::allgatherv(
+            input.data_ptr(),
+            (size_t)input.numel(),
+            recvBufs,
+            recvCounts,
+            getXcclDataType(input.scalar_type()),
+            comm,
+            ccl::create_stream(stream.queue()));
+        return;
       },
       [](at::xpu::XPUStream&,
          c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
@@ -1137,19 +1090,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
        }
      },
      OpType::ALLGATHER);
-  } else {
-    const auto num_reduces = outputTensors_.size();
-    startCoalescing();
-    for (const int i : c10::irange(num_reduces)) {
-      auto& output = outputTensors_[i];
-      auto& input = (i == rank_) ? inputTensor : output;
-      auto broadcastOpts = BroadcastOptions{
-          static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
-      _broadcast_oop(output, input, broadcastOpts);
-    }
-    auto work = endCoalescing(OpType::ALLGATHER);
-    return work;
-  }
+}
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
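The allgather path above now gathers directly into the caller's per-rank output tensors via a vector-variant allgather, instead of flattening into a stacked tensor and copying back. A minimal, self-contained sketch of how the per-rank receive counts and buffer pointers are assembled (standard library plus ATen only; the helper name is illustrative, and the exact ccl::allgatherv overload and argument order should be verified against the oneCCL headers in use):

    #include <ATen/ATen.h>
    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Build the receive-count and receive-buffer vectors an allgatherv call
    // needs, where outputs[r] is the tensor that will hold rank r's data.
    std::pair<std::vector<size_t>, std::vector<void*>> makeAllgathervArgs(
        const std::vector<at::Tensor>& outputs) {
      std::vector<size_t> recvCounts(outputs.size(), 0);
      std::transform(
          outputs.begin(), outputs.end(), recvCounts.begin(),
          [](const at::Tensor& t) { return static_cast<size_t>(t.numel()); });

      std::vector<void*> recvBufs(outputs.size(), nullptr);
      std::transform(
          outputs.begin(), outputs.end(), recvBufs.begin(),
          [](const at::Tensor& t) { return t.data_ptr(); });

      return {std::move(recvCounts), std::move(recvBufs)};
    }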
@@ -1159,16 +1100,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
   check_xpu_single_tensor(input_tensor);
   check_xpu_single_tensor(output_tensor);
 
-  if (input_tensor.dtype() != output_tensor.dtype()) {
-    C10_THROW_ERROR(
-        TypeError, "output tensor must have the same type as input tensor");
-  }
-
-  if (input_tensor.numel() * size_ != output_tensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "output tensor size must be equal to world_size times input tensor size");
-  }
+  TORCH_CHECK_WITH(
+      TypeError,
+      input_tensor.dtype() == output_tensor.dtype(),
+      "output tensor must have the same type as input tensor");
+  TORCH_CHECK_WITH(
+      ValueError,
+      input_tensor.numel() * size_ == output_tensor.numel(),
+      "output tensor size must be equal to world_size times input tensor size");
 
   return collective(
       input_tensor,
@@ -1230,9 +1165,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
   check_xpu_single_tensor(outputTensor);
   // @lint-ignore CLANGTIDY
   auto inputTensors_ = inputTensors.back();
-  TORCH_CHECK(
-      !isFloat8Type(outputTensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
   bool same_size = check_same_size(inputTensors_);
   if (same_size) {
@@ -1248,7 +1180,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
            at::xpu::XPUStream& stream) {
          c10::xpu::XPUCachingAllocator::recordStream(
              output.storage().data_ptr(), stream);
-          auto xcclDataType = getXcclDataType(input.scalar_type());
+          auto xcclDataType = getXcclDataType(input.scalar_type(), true);
          auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
          ccl::reduce_scatter(
              input.data_ptr(),
@@ -1295,22 +1227,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
     at::Tensor& outputTensor,
     at::Tensor& inputTensor,
     const ReduceScatterOptions& opts) {
-  if (inputTensor.dtype() != outputTensor.dtype()) {
-    C10_THROW_ERROR(
-        TypeError, "input tensor must be the same type as the output tensor.");
-  }
-
-  if (inputTensor.numel() != outputTensor.numel() * size_) {
-    C10_THROW_ERROR(
-        ValueError,
-        "input tensor must be the same size as output size times world size");
-  }
-
-  // @lint-ignore CLANGTIDY
-  const auto& tensor = outputTensor;
-  TORCH_CHECK(
-      !isFloat8Type(tensor.scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
+  TORCH_CHECK_WITH(
+      TypeError,
+      inputTensor.dtype() == outputTensor.dtype(),
+      "input tensor must have the same type as the output tensor");
+  TORCH_CHECK_WITH(
+      ValueError,
+      inputTensor.numel() == outputTensor.numel() * size_,
+      "input tensor size must be equal to world_size times output tensor size");
 
   return collective(
       inputTensor,
@@ -1322,7 +1242,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
          at::xpu::XPUStream& stream) {
        c10::xpu::XPUCachingAllocator::recordStream(
            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        ccl::reduce_scatter(
            input.data_ptr(),
@@ -1341,9 +1261,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& outputs,
     std::vector<at::Tensor>& inputs,
     const ReduceScatterOptions& opts) {
-  TORCH_CHECK(
-      !isFloat8Type(inputs.back().scalar_type()),
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
 
   return collectiveCoalesced(
       inputs,
       outputs,
@@ -1354,7 +1271,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
          at::xpu::XPUStream& stream) {
        c10::xpu::XPUCachingAllocator::recordStream(
            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
        ccl::reduce_scatter(
            input.data_ptr(),