diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 55339880a82a3..2119dd1932800 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1014,10 +1014,6 @@ elseif(USE_CUDA) endif() if(USE_XPU) - # if(USE_XCCL) - # list(APPEND Caffe2_XPU_SRCS - # ${TORCH_SRC_DIR}/csrc/xpu/xccl.cpp) - # endif() add_library(torch_xpu ${Caffe2_XPU_SRCS}) torch_compile_options(torch_xpu) # see cmake/public/utils.cmake target_compile_definitions(torch_xpu PRIVATE USE_XPU) @@ -1373,7 +1369,14 @@ if(USE_DISTRIBUTED) endif() endif() if(USE_C10D_XCCL) + # if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + # set_source_files_properties( + # ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp + # PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_SYCL") + # target_sources(torch_xpu PRIVATE ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp) + # endif() target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) + target_compile_definitions(torch_xpu PUBLIC CCL_ENABLE_SYCL) endif() if(USE_MPI AND USE_C10D_MPI) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index e21be88ef83d1..cabdb9f61433b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -3,7 +3,7 @@ #include #include -// #ifdef USE_C10D_XCCL +#ifdef USE_C10D_XCCL #include #include #include @@ -174,20 +174,16 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( numRanks = getSize(); rank = getRank(); - ccl::vector_class> devs_rank; c10::impl::VirtualGuardImpl impl(device.type()); c10::Stream stream = impl.getStream(device); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); - // const sycl::context& sycl_ctx = q.get_context(); - // sycl::context sycl_ctx = q.get_context(); - // ccl::generic_context_type ccl_ctx(sycl_ctx); - // auto ctx = ccl::create_context(ccl_ctx.get()); - // auto ctx = ccl::create_context(q.get_context()); - // devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - // XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs); - XCCLComm = std::make_shared(ccl::create_communicator(numRanks, rank, kvs)); + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs); + XCCLComm = std::make_shared(std::move(comms[0])); { std::lock_guard lock(mutex_); @@ -242,11 +238,10 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( c10::xpu::XPUCachingAllocator::recordStream( input.storage().data_ptr(), stream); - // auto ccl_stream = ccl::create_stream(stream.queue()); - auto ccl_stream = ccl::create_stream(); + auto ccl_stream = ccl::create_stream(stream.queue()); + // auto ccl_stream = ccl::create_stream(); fn(input, output, attr, *comm, ccl_stream); - // fn(input, output, attr, comm, ccl_stream); work->xcclEndEvent_->record(stream); @@ -312,4 +307,4 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( } // namespace c10d -// #endif // USE_C10D_XCCL +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 9ad20797afcb6..f8b9d15bd6548 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -7,12 +7,11 @@ #include #endif -// #ifdef USE_C10D_XCCL +#ifdef USE_C10D_XCCL #include #include #include #include -// #include #include #include #include @@ -36,7 +35,6 @@ namespace c10d { using xcclComm_t = ccl::communicator; using XCCL_KVS = ccl::shared_ptr_class; constexpr const char* XCCL_BACKEND_NAME = "xccl"; -// using namespace torch::xpu::xccl; class TORCH_API ProcessGroupXCCL : public Backend { public: @@ -47,30 +45,11 @@ class TORCH_API ProcessGroupXCCL : public Backend { int rank, OpType opType, const std::optional>& inputs = std::nullopt); - // WorkXCCL( - // std::vector> outputTensors, - // int rank = -1, - // OpType opType = OpType::UNKNOWN, - // const c10::optional>& inputTensors = - // c10::nullopt) - // : Work(rank, opType), outputTensors_(std::move(outputTensors)) {} WorkXCCL(const WorkXCCL& w); - // ~WorkXCCL() override { - // // Ensures all events are properly handled before destruction - // for (auto& event : events_) { - // event.wait(); - // } - // } ~WorkXCCL() override; bool isCompleted() override { TORCH_CHECK( false, "ProcessGroupXCCL::WorkXCCL::isCompleted not implemented"); - // for (auto& event : events_) { - // if (!event.test()) { - // return false; - // } - // } - // return true; } bool isSuccess() const override { @@ -97,9 +76,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: at::Device device_; std::shared_ptr xcclEndEvent_; - // std::vector events_; - // std::shared_ptr xcclComm_; - // const std::vector> outputTensors_; private: std::shared_ptr> outputs_; c10::intrusive_ptr future_; @@ -110,7 +86,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { const c10::intrusive_ptr& store, int rank, int size) - : store_(store), Backend(rank, size) {} + : Backend(rank, size), store_(store) {} ~ProcessGroupXCCL() override; @@ -168,4 +144,4 @@ class TORCH_API ProcessGroupXCCL : public Backend { } // namespace c10d -// #endif // USE_C10D_XCCL +#endif // USE_C10D_XCCL