diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 32030f45e73ae3..e973ce110ab0bc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -4,6 +4,7 @@ #include #ifdef USE_C10D_XCCL +#include #include #include #include @@ -459,6 +460,8 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( work->outputs_ = std::make_shared>(outputs); + at::xpu::OptionalXPUGuard gpuGuard(device); + pre(stream, work); for (const auto& input : inputs) { @@ -470,7 +473,9 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( post(stream, work); - work->xcclEndEvent_->record(stream); + if (!coalescing_state_) { + work->xcclEndEvent_->record(stream); + } std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); @@ -550,6 +555,8 @@ c10::intrusive_ptr ProcessGroupXCCL::collectiveCoalesced( work->outputs_ = std::make_shared>(outputs); + at::xpu::OptionalXPUGuard gpuGuard(device); + { AutoXcclGroup xccl_group_guard; for (const auto i : c10::irange(inputs.size())) {