
Commit 7c2f018

update
Chao1Han committed Sep 6, 2024
1 parent 4f73180 commit 7c2f018
Showing 3 changed files with 19 additions and 45 deletions.
11 changes: 7 additions & 4 deletions caffe2/CMakeLists.txt
@@ -1014,10 +1014,6 @@ elseif(USE_CUDA)
 endif()
 
 if(USE_XPU)
-  # if(USE_XCCL)
-  #   list(APPEND Caffe2_XPU_SRCS
-  #     ${TORCH_SRC_DIR}/csrc/xpu/xccl.cpp)
-  # endif()
   add_library(torch_xpu ${Caffe2_XPU_SRCS})
   torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
   target_compile_definitions(torch_xpu PRIVATE USE_XPU)
@@ -1373,7 +1369,14 @@ if(USE_DISTRIBUTED)
       endif()
     endif()
     if(USE_C10D_XCCL)
+      # if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+      #   set_source_files_properties(
+      #     ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+      #     PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_SYCL")
+      #   target_sources(torch_xpu PRIVATE ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp)
+      # endif()
       target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+      target_compile_definitions(torch_xpu PUBLIC CCL_ENABLE_SYCL)
     endif()
     if(USE_MPI AND USE_C10D_MPI)
       if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
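
Note: both definitions are PUBLIC on torch_xpu, so they propagate to every target that links against torch_xpu instead of applying only to torch_xpu's own sources (replacing the commented-out per-source COMPILE_DEFINITIONS approach above). A minimal sketch of what that enables downstream, assuming only that the consumer links torch_xpu:

    // Compiled only when torch_xpu was configured with USE_C10D_XCCL;
    // the macro arrives automatically via target_link_libraries(... torch_xpu).
    #ifdef USE_C10D_XCCL
    #include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
    #endif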
23 changes: 9 additions & 14 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -3,7 +3,7 @@
 #include <mutex>
 #include <sstream>
 
-// #ifdef USE_C10D_XCCL
+#ifdef USE_C10D_XCCL
 #include <exception>
 #include <map>
 #include <stdexcept>
@@ -174,20 +174,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   numRanks = getSize();
   rank = getRank();
 
-  ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
   c10::impl::VirtualGuardImpl impl(device.type());
   c10::Stream stream = impl.getStream(device);
   sycl::queue& q = c10::xpu::XPUStream(stream).queue();
-  // const sycl::context& sycl_ctx = q.get_context();
-  // sycl::context sycl_ctx = q.get_context();
-  // ccl::generic_context_type<ccl::cl_backend_type::dpcpp_sycl_l0> ccl_ctx(sycl_ctx);
-  // auto ctx = ccl::create_context(ccl_ctx.get());
 
-  // auto ctx = ccl::create_context(q.get_context());
-  // devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
-  // XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
-  XCCLComm = std::make_shared<xcclComm_t>(ccl::create_communicator(numRanks, rank, kvs));
+  auto ctx = ccl::create_context(q.get_context());
+  ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
+  devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
 
+  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
+  XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
 
   {
     std::lock_guard<std::mutex> lock(mutex_);
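
Note: the old path built a host-side communicator with ccl::create_communicator(numRanks, rank, kvs); the new path goes through the device-aware ccl::create_communicators, binding the communicator to this rank's SYCL device and context so collectives can operate on GPU buffers. A standalone sketch of the same flow, assuming a sycl::queue q, the rank/numRanks values, and a kvs exchanged through the store as in this function:

    // One (rank, device) pair per local device; with a single device per
    // rank the returned vector holds exactly one communicator.
    auto ctx = ccl::create_context(q.get_context());
    ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
    devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
    auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
    xcclComm_t comm = std::move(comms[0]); // ccl::communicator is move-only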
@@ -242,11 +238,10 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
       c10::xpu::XPUCachingAllocator::recordStream(
           input.storage().data_ptr(), stream);
 
-      // auto ccl_stream = ccl::create_stream(stream.queue());
-      auto ccl_stream = ccl::create_stream();
+      auto ccl_stream = ccl::create_stream(stream.queue());
+      // auto ccl_stream = ccl::create_stream();
 
       fn(input, output, attr, *comm, ccl_stream);
-      // fn(input, output, attr, comm, ccl_stream);
 
       work->xcclEndEvent_->record(stream);

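Note: ccl::create_stream(stream.queue()) wraps the current XPU stream's SYCL queue, so the collective is ordered after kernels already enqueued on that queue; the bare ccl::create_stream() it replaces yields a default stream with no such ordering guarantee. A hedged sketch of issuing a collective on a queue-backed stream (the float buffers and explicit wait are illustrative, not from this commit):

    // send_buf/recv_buf are device float pointers holding `count` elements.
    auto ccl_stream = ccl::create_stream(q);
    ccl::allreduce(send_buf, recv_buf, count, ccl::reduction::sum, comm, ccl_stream)
        .wait(); // block until the allreduce completes
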
@@ -312,4 +307,4 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
 
 } // namespace c10d
 
-// #endif // USE_C10D_XCCL
+#endif // USE_C10D_XCCL
30 changes: 3 additions & 27 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -7,12 +7,11 @@
 #include <unistd.h>
 #endif
 
-// #ifdef USE_C10D_XCCL
+#ifdef USE_C10D_XCCL
 #include <oneapi/ccl.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
 #include <torch/csrc/xpu/Event.h>
 #include <torch/csrc/xpu/Stream.h>
-// #include <torch/csrc/xpu/xccl.h>
 #include <exception>
 #include <memory>
 #include <vector>
@@ -36,7 +35,6 @@ namespace c10d {
 using xcclComm_t = ccl::communicator;
 using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
-// using namespace torch::xpu::xccl;
 
 class TORCH_API ProcessGroupXCCL : public Backend {
  public:
@@ -47,30 +45,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
         int rank,
         OpType opType,
         const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
-    // WorkXCCL(
-    //     std::vector<std::vector<at::Tensor>> outputTensors,
-    //     int rank = -1,
-    //     OpType opType = OpType::UNKNOWN,
-    //     const c10::optional<std::vector<at::Tensor>>& inputTensors =
-    //         c10::nullopt)
-    //     : Work(rank, opType), outputTensors_(std::move(outputTensors)) {}
     WorkXCCL(const WorkXCCL& w);
-    // ~WorkXCCL() override {
-    //   // Ensures all events are properly handled before destruction
-    //   for (auto& event : events_) {
-    //     event.wait();
-    //   }
-    // }
     ~WorkXCCL() override;
     bool isCompleted() override {
       TORCH_CHECK(
           false, "ProcessGroupXCCL::WorkXCCL::isCompleted not implemented");
-      // for (auto& event : events_) {
-      //   if (!event.test()) {
-      //     return false;
-      //   }
-      // }
-      // return true;
     }
 
     bool isSuccess() const override {
@@ -97,9 +76,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
    protected:
     at::Device device_;
     std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
-    // std::vector<ccl::event> events_;
-    // std::shared_ptr<xcclComm_t> xcclComm_;
-    // const std::vector<std::vector<at::Tensor>> outputTensors_;
    private:
     std::shared_ptr<std::vector<at::Tensor>> outputs_;
     c10::intrusive_ptr<at::ivalue::Future> future_;
@@ -110,7 +86,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       const c10::intrusive_ptr<Store>& store,
       int rank,
       int size)
-      : store_(store), Backend(rank, size) {}
+      : Backend(rank, size), store_(store) {}
 
   ~ProcessGroupXCCL() override;

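Note: the initializer-list swap fixes an ordering mismatch rather than style. C++ always initializes the base class before non-static members, regardless of the order written in the list, so store_(store), Backend(rank, size) already ran Backend first and merely drew a -Wreorder warning; the new spelling matches the actual order. A reduced illustration:

    struct Backend { Backend(int rank, int size) {} };
    struct PG : Backend {
      // Base first, then members in declaration order: listing them this way
      // matches what the compiler actually does and silences -Wreorder(-ctor).
      PG(int rank, int size) : Backend(rank, size), store_(nullptr) {}
      void* store_;
    };
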
@@ -168,4 +144,4 @@ class TORCH_API ProcessGroupXCCL : public Backend {
 
 } // namespace c10d
 
-// #endif // USE_C10D_XCCL
+#endif // USE_C10D_XCCL
