From 22e80041db624594c7f72b527ec22bed3d7edb02 Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 11 Sep 2024 05:34:46 +0000 Subject: [PATCH] Support broadcast --- .../distributed/c10d/ProcessGroupXCCL.cpp | 36 +++++++++++++++++++ .../distributed/c10d/ProcessGroupXCCL.hpp | 4 +-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index a164f5c4493101..dfd168b9eeff2d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -462,6 +462,42 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( OpType::COALESCED); } +c10::intrusive_ptr ProcessGroupXCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK( + tensors.size() == 1, "Expecting one tensor only but got multiple"); + auto tensor = tensors.back(); + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } + check_xpu_single_tensor(tensor); + + const auto root = opts.rootRank + opts.rootTensor; + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::broadcast_attr attr, + xcclComm_t& comm, + ccl::stream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::event ret_evt; + ret_evt = ccl::broadcast( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + stream, + attr); + return ret_evt; + }, + OpType::BROADCAST); +} + } // namespace c10d #endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 28bbc9ae012db1..3a60471fa8d05b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -157,9 +157,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr broadcast( std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); - } + const BroadcastOptions& opts = BroadcastOptions()) override; c10::intrusive_ptr allgather( std::vector>& outputTensors,