diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index ce482a97952c34..ea5220e3dac77b 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -29,9 +29,13 @@ namespace {
 // wait nonblocking implement
 AutoXcclGroup::AutoXcclGroup() {
-  comm_ = nullptr;
+  ccl::group_start();
+}
+
+AutoXcclGroup::AutoXcclGroup(std::shared_ptr<xcclComm_t> comm) {
+  comm_ = std::move(comm);
   ccl::group_start();
 }
 
-AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
+AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
   ccl::group_end();
 }
@@ -435,7 +439,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
     std::vector<at::Tensor>& outputs,
     Fn fn,
     OpType opType) {
-
   using traits = function_traits<Fn>;
   using attr_t = typename traits::template arg<2>::type;
   attr_t attr = ccl::create_operation_attr<attr_t>();
@@ -511,7 +514,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
     std::vector<at::Tensor>& tensors,
-    const AllreduceCoalescedOptions& opts =
-        AllreduceCoalescedOptions()) {
+    const AllreduceCoalescedOptions& opts) {
   check_xpu_tensors_same_device(tensors);
   TORCH_CHECK(
       !isFloat8Type(tensors.back().scalar_type()),
@@ -523,6 +525,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
       tensors,
       [&](at::Tensor& input,
           at::Tensor& output,
+          ccl::allreduce_attr attr,
          xcclComm_t& comm,
           ccl::stream& stream) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
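
For context, the AutoXcclGroup changes above follow the usual RAII grouping pattern: ccl::group_start() in the constructors, ccl::group_end() in the destructor, so every collective issued inside the guard's scope ends up in one group even on early return, which is what collectiveCoalesced() relies on. Below is a minimal standalone sketch of that pattern; the group_start/group_end stubs and the issue_coalesced caller are placeholders for illustration, not the oneCCL API or the actual ProcessGroupXCCL code.

#include <iostream>

// Stand-ins for ccl::group_start() / ccl::group_end(); illustration only.
void group_start() { std::cout << "group_start\n"; }
void group_end() { std::cout << "group_end\n"; }

// RAII guard mirroring AutoXcclGroup: opens a group on construction and
// closes it on destruction, so the group is always ended, even on early exit.
struct AutoGroupGuard {
  AutoGroupGuard() { group_start(); }
  ~AutoGroupGuard() { group_end(); }
};

// Hypothetical caller: everything issued inside the guard's scope belongs to
// a single group, the way collectiveCoalesced() wraps its per-tensor loop.
void issue_coalesced(int n) {
  AutoGroupGuard guard;
  for (int i = 0; i < n; ++i) {
    std::cout << "  enqueue collective " << i << "\n";
  }
} // group_end() runs here when `guard` is destroyed

int main() {
  issue_coalesced(3);
}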