diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index fd02226a1dd772..8f689ec80eb12a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -167,14 +167,19 @@ bool ProcessGroupXCCL::WorkXCCL::checkTimeout(
   return true;
 }
 
+void ProcessGroupXCCL::WorkXCCL::finishWorkXcclError(  // Shared error path for this work object.
+    const std::exception_ptr& eptr) {  // eptr: the exception to record.
+  future_->setError(eptr);  // Record the exception on future_ first ...
+  finish(eptr);  // ... then pass the same exception to finish().
+}
+
 bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
   for (auto& ret : rets) {
     bool flag;
     try {
       TORCH_CHECK(flag = ret.test());
     } catch (...) {
-      future_->setError(std::current_exception());
-      finish(std::current_exception());
+      finishWorkXcclError(std::current_exception());  // Replaces the inline setError()/finish() pair; current_exception() is now called once.
       return true;
     }
     if (!flag) {
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 6d4cc5097ebbc5..37e36047a63c16 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -88,6 +88,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       rets.push_back(std::move(result));
     }
 
+    void finishWorkXcclError(const std::exception_ptr& eptr);  // Sets eptr on future_, then calls finish(eptr); defined in ProcessGroupXCCL.cpp.
+
     bool isCompleted() override;
 
     bool isSuccess() const override {