diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index b51d299b47d8d..fd02226a1dd77 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -36,7 +36,6 @@ std::map xcclOps = { std::map xcclDatatypes = { {at::kByte, ccl::datatype::uint8}, {at::kChar, ccl::datatype::int8}, - {at::kShort, ccl::datatype::int16}, {at::kInt, ccl::datatype::int32}, {at::kLong, ccl::datatype::int64}, {at::kHalf, ccl::datatype::float16}, @@ -148,9 +147,9 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL( ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) : Work(w.rank_, w.opType_), device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), blockingWait_(w.blockingWait_), - workStartTime_(w.workStartTime_), - xcclEndEvent_(w.xcclEndEvent_) {} + workStartTime_(w.workStartTime_) {} ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; @@ -174,7 +173,8 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() { try { TORCH_CHECK(flag = ret.test()); } catch (...) { - finishAWorkXCCLError(std::current_exception()); + future_->setError(std::current_exception()); + finish(std::current_exception()); return true; } if (!flag) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 6d946acbea804..6d4cc5097ebbc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -124,10 +124,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector rets; private: - void finishAWorkXCCLError(std::exception_ptr eptr) { - future_->setError(eptr); - finish(eptr); - } void synchronizeInternal(std::chrono::milliseconds timeout); std::shared_ptr> outputs_; c10::intrusive_ptr future_;