Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Sep 12, 2024
1 parent e85c268 commit 7488dbd
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,23 @@ namespace {

// wait nonblocking implement
AutoXcclGroup::AutoXcclGroup() {
<<<<<<< HEAD
ccl::group_start();
}

AutoXcclGroup::AutoXcclGroup(std::shared_ptr<xcclComm_t> comm) {
comm_ = std::move(comm);
ccl::group_start();
}

AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
=======
comm_ = nullptr;
ccl::group_start();
}

AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
>>>>>>> e85c26816e3dacf7244cc0d4f5abe1914f79fe66
ccl::group_end();
}

Expand Down Expand Up @@ -175,6 +187,10 @@ constexpr int64_t kSynchronizeBusyWaitMillis = 10;
// Before implementing send/recv, the xcclActiveGroupCounter_ variable has no effect.
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;

// Before implementing send/recv, the xcclActiveGroupCounter_ variable has no
// effect.
thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;

ProcessGroupXCCL::WorkXCCL::WorkXCCL(
at::Device& device,
int rank,
Expand Down Expand Up @@ -435,7 +451,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
std::vector<at::Tensor>& outputs,
Fn fn,
OpType opType) {

using traits = function_traits<Fn>;
using attr_t = typename traits::template arg<2>::type;
attr_t attr = ccl::create_operation_attr<attr_t>();
Expand Down Expand Up @@ -511,8 +526,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(

c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) {
const AllreduceCoalescedOptions& opts) {
check_xpu_tensors_same_device(tensors);
TORCH_CHECK(
!isFloat8Type(tensors.back().scalar_type()),
Expand All @@ -523,6 +537,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
tensors,
[&](at::Tensor& input,
at::Tensor& output,
ccl::allreduce_attr attr,
xcclComm_t& comm,
ccl::stream& stream) {
auto xcclDataType = getXcclDataType(input.scalar_type());
Expand Down

0 comments on commit 7488dbd

Please sign in to comment.