
Commit

update
Chao1Han committed Sep 4, 2024
1 parent 076db36 commit 8d739ac
Showing 6 changed files with 80 additions and 455 deletions.
1 change: 0 additions & 1 deletion build_variables.bzl
@@ -785,7 +785,6 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
]

libtorch_python_xpu_sources = [
"torch/csrc/xpu/xccl.cpp",
"torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp",
"torch/csrc/xpu/Event.cpp",
"torch/csrc/xpu/Module.cpp",
8 changes: 4 additions & 4 deletions caffe2/CMakeLists.txt
@@ -1014,10 +1014,10 @@ elseif(USE_CUDA)
endif()

if(USE_XPU)
if(USE_XCCL)
list(APPEND Caffe2_XPU_SRCS
${TORCH_SRC_DIR}/csrc/xpu/xccl.cpp)
endif()
# if(USE_XCCL)
# list(APPEND Caffe2_XPU_SRCS
# ${TORCH_SRC_DIR}/csrc/xpu/xccl.cpp)
# endif()
add_library(torch_xpu ${Caffe2_XPU_SRCS})
torch_compile_options(torch_xpu) # see cmake/public/utils.cmake
target_compile_definitions(torch_xpu PRIVATE USE_XPU)
61 changes: 40 additions & 21 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -11,6 +11,7 @@
#include <unordered_set>
#include <utility>

#include <ATen/detail/FunctionTraits.h>
#include <c10/core/DeviceType.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
@@ -45,6 +46,36 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kBool, ccl::datatype::uint8},
};

XCCL_KVS kvs;
std::mutex kvs_mutex;

XCCL_KVS get_kvs(int rank, c10d::Store& store) {
std::lock_guard<std::mutex> lock(kvs_mutex);
if (kvs)
return kvs;
std::string storeKey = "ccl_kvs";

// Rank 0 broadcast the bootstrap network information to other ranks
if (rank == 0) {
kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = kvs->get_address();
auto ccl_kvs_addr =
std::vector<uint8_t>(main_addr.begin(), main_addr.end());
store.set(storeKey, ccl_kvs_addr);
} else {
auto ccl_kvs_addr = store.get(storeKey);
if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
}
ccl::kvs::address_type main_addr;
std::copy_n(
ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
kvs = ccl::create_kvs(main_addr);
}

return kvs;
}

void check_xpu_single_tensor(const at::Tensor& tensor) {
if (!tensor.is_xpu() || tensor.is_sparse()) {
C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
@@ -89,11 +120,6 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
return true;
}

c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupXCCL::WorkXCCL::
getFuture() {
return future_;
}

void ProcessGroupXCCL::WorkXCCL::synchronize() {
auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
// Block the current stream on the XCCL stream
@@ -107,12 +133,6 @@ c10::intrusive_ptr<Backend> ProcessGroupXCCL::createProcessGroupXCCL(
return c10::make_intrusive<ProcessGroupXCCL>(store, rank, size);
}

ProcessGroupXCCL::ProcessGroupXCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size)
: Backend(rank, size), store_(store) {}

ProcessGroupXCCL::~ProcessGroupXCCL() = default;

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
@@ -148,7 +168,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(

std::shared_ptr<xcclComm_t> XCCLComm;

XCCL_KVS kvs = get_kvs(rank_, store_);
XCCL_KVS kvs = get_kvs(rank_, *store_);

int numRanks, rank;
numRanks = getSize();
@@ -157,7 +177,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream stream = impl.getStream(device);
auto q = at::xpu::XPUStream(stream).queue();
sycl::queue& q = c10::xpu::XPUStream(stream).queue();
auto ctx = ccl::create_context(q.get_context());
devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
@@ -208,20 +228,20 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(

c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;

work =initWork(device, rank_, op_type);
// work = make_work_ccl<WorkXCCL>(
// inputs, outputs, fn, xcclComm_t, attr, rank_, op_type);
// work->events_.emplace_back(fn);
work = initWork(device, rank_, opType);

work->outputs_ =
std::make_shared<std::vector<at::Tensor>>(std::move(outputs));
c10::xpu::XPUCachingAllocator::recordStream(
input.storage().data_ptr(), xcclStream);

auto ccl_stream = ccl::create_stream(at::xpu::XPUStream(xcclStream).queue());
auto ccl_stream = ccl::create_stream(xcclStream.queue());
fn(input, output, attr, comm, ccl_stream);

work->xcclEndEvent_->record(xcclStream);
c10::MultiStreamGuard streamGuard(xcclStream);

std::vector<c10::Stream> streams = {xcclStream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()), devices);
@@ -266,13 +286,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
xcclComm_t comm,
ccl::stream& stream) {
ccl::event ret_evt;
ccl::datatype datatype = getXcclDataType(input.scalar_type());
ret_evt = ccl::allreduce(
input.data_ptr(),
output.data_ptr(),
(size_t)input.numel(),
getXcclDataType(input.scalar_type()),
xcclOp.at(opts.reduceOp),
xcclOps.at(opts.reduceOp),
comm,
stream,
attr);
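
For orientation, here is a minimal, hedged sketch of how the collective() template above is expected to be driven by the public allreduce entry point: the caller validates the tensor, then hands collective() a lambda that issues the oneCCL allreduce on the communicator's stream. Names taken from the diff (collective, check_xpu_single_tensor, getXcclDataType, xcclOps, OpType) are assumed to exist in this file; the attr parameter type and parameter ordering are approximations, not the exact upstream code.

// Hedged sketch, not a verbatim copy of the commit. Illustrates how
// ProcessGroupXCCL::allreduce could feed the collective() template shown above.
c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  TORCH_CHECK(tensors.size() == 1, "XCCL allreduce expects one tensor per rank");
  auto tensor = tensors.back();
  check_xpu_single_tensor(tensor); // helper defined earlier in this file

  return collective(
      tensor,
      tensor,
      [&](at::Tensor& input,
          at::Tensor& output,
          ccl::allreduce_attr attr, // attr type is an assumption
          xcclComm_t& comm,
          ccl::stream& stream) {
        // Issue the oneCCL allreduce on the XCCL stream's SYCL queue.
        return ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
            static_cast<size_t>(input.numel()),
            getXcclDataType(input.scalar_type()),
            xcclOps.at(opts.reduceOp),
            comm,
            stream,
            attr);
      },
      OpType::ALLREDUCE);
}
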
40 changes: 36 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -8,9 +8,11 @@
#endif

// #ifdef USE_C10D_XCCL

#include <oneapi/ccl.hpp>
#include <torch/csrc/xpu/xccl.h>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/xpu/Event.h>
#include <torch/csrc/xpu/Stream.h>
// #include <torch/csrc/xpu/xccl.h>
#include <exception>
#include <memory>
#include <vector>
@@ -24,14 +26,17 @@
#include <thread>
#include <unordered_map>

#include <c10/core/StreamGuard.h>
#include <c10/xpu/XPUCachingAllocator.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>

namespace c10d {

using xcclComm_t = ccl::communicator;
using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
constexpr const char* XCCL_BACKEND_NAME = "xccl";
using namespace torch::xpu::xccl;
// using namespace torch::xpu::xccl;

class TORCH_XPU_API ProcessGroupXCCL : public Backend {
public:
@@ -113,6 +118,33 @@ class TORCH_XPU_API ProcessGroupXCCL : public Backend {
return std::string(XCCL_BACKEND_NAME);
}

std::shared_ptr<xcclComm_t> getXCCLComm(
const std::string& deviceKey,
at::Device& device);

virtual c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> initWork(
at::Device& device,
int rank,
OpType opType,
const std::vector<at::Tensor>& inputs = {},
const std::vector<at::Tensor>& outputs = {});

template <typename Fn>
c10::intrusive_ptr<Work> collective(
at::Tensor& input,
at::Tensor& output,
Fn fn,
OpType opType);

template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
at::Tensor& input,
at::Tensor& output,
Fn fn,
PreProcess pre,
PostProcess post,
OpType opType);

c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
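
To make the intended call path concrete, a minimal usage sketch follows: one rank builds a c10d store, constructs the XCCL backend through createProcessGroupXCCL, and runs an allreduce on a dense XPU tensor. The TCPStore setup, port defaults, and single-rank world size are illustrative assumptions and are not part of this commit.

// Hedged usage sketch (not part of the commit): drive the XCCL backend
// directly from C++. Store construction details are assumptions for illustration.
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>

int main() {
  int rank = 0, world_size = 1; // assume a single-rank run for illustration

  // Rank 0 hosts the TCP store used for the oneCCL KVS bootstrap shown above.
  c10d::TCPStoreOptions storeOpts;
  storeOpts.isServer = (rank == 0);
  auto store = c10::make_intrusive<c10d::TCPStore>("127.0.0.1", storeOpts);

  auto pg = c10d::ProcessGroupXCCL::createProcessGroupXCCL(
      store, rank, world_size);

  // One dense XPU tensor per rank, as required by check_xpu_single_tensor.
  std::vector<at::Tensor> tensors{
      at::ones({8}, at::TensorOptions().device(at::kXPU))};

  auto work = pg->allreduce(tensors);
  work->wait(); // synchronizes the current XPU stream with the XCCL stream
  return 0;
}
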
