diff --git a/build_variables.bzl b/build_variables.bzl
index b903a55b17439..cff70d00320b0 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -785,7 +785,6 @@ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
 ]
 
 libtorch_python_xpu_sources = [
-    "torch/csrc/xpu/xccl.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp",
     "torch/csrc/xpu/Event.cpp",
     "torch/csrc/xpu/Module.cpp",
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 01d280cb3fc7c..55339880a82a3 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1014,10 +1014,6 @@ elseif(USE_CUDA)
 endif()
 
 if(USE_XPU)
-  if(USE_XCCL)
-    list(APPEND Caffe2_XPU_SRCS
-      ${TORCH_SRC_DIR}/csrc/xpu/xccl.cpp)
-  endif()
   add_library(torch_xpu ${Caffe2_XPU_SRCS})
   torch_compile_options(torch_xpu)  # see cmake/public/utils.cmake
   target_compile_definitions(torch_xpu PRIVATE USE_XPU)
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index 8be7c6451fcdd..ffd566f10f854 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
 #include
@@ -45,6 +46,36 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kBool, ccl::datatype::uint8},
 };
 
+XCCL_KVS kvs;
+std::mutex kvs_mutex;
+
+XCCL_KVS get_kvs(int rank, c10d::Store& store) {
+  std::lock_guard<std::mutex> lock(kvs_mutex);
+  if (kvs)
+    return kvs;
+  std::string storeKey = "ccl_kvs";
+
+  // Rank 0 broadcasts the bootstrap network information to the other ranks
+  if (rank == 0) {
+    kvs = ccl::create_main_kvs();
+    ccl::kvs::address_type main_addr = kvs->get_address();
+    auto ccl_kvs_addr =
+        std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+    store.set(storeKey, ccl_kvs_addr);
+  } else {
+    auto ccl_kvs_addr = store.get(storeKey);
+    if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
+      throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
+    }
+    ccl::kvs::address_type main_addr;
+    std::copy_n(
+        ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
+    kvs = ccl::create_kvs(main_addr);
+  }
+
+  return kvs;
+}
+
 void check_xpu_single_tensor(const at::Tensor& tensor) {
   if (!tensor.is_xpu() || tensor.is_sparse()) {
     C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
@@ -89,11 +120,6 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
-c10::intrusive_ptr<c10::ivalue::Future> ProcessGroupXCCL::WorkXCCL::
-    getFuture() {
-  return future_;
-}
-
 void ProcessGroupXCCL::WorkXCCL::synchronize() {
   auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
   // Block the current stream on the XCCL stream
@@ -107,12 +133,6 @@ c10::intrusive_ptr<Backend> ProcessGroupXCCL::createProcessGroupXCCL(
   return c10::make_intrusive<ProcessGroupXCCL>(store, rank, size);
 }
 
-ProcessGroupXCCL::ProcessGroupXCCL(
-    const c10::intrusive_ptr<Store>& store,
-    int rank,
-    int size)
-    : Backend(rank, size), store_(store) {}
-
 ProcessGroupXCCL::~ProcessGroupXCCL() = default;
 
 c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
@@ -148,7 +168,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
 
   std::shared_ptr<xcclComm_t> XCCLComm;
 
-  XCCL_KVS kvs = get_kvs(rank_, store_);
+  XCCL_KVS kvs = get_kvs(rank_, *store_);
 
   int numRanks, rank;
   numRanks = getSize();
@@ -157,7 +177,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
   c10::impl::VirtualGuardImpl impl(device.type());
   c10::Stream stream = impl.getStream(device);
-  auto q = at::xpu::XPUStream(stream).queue();
+  sycl::queue& q = c10::xpu::XPUStream(stream).queue();
   auto ctx = ccl::create_context(q.get_context());
   devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
   XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
@@ -208,20 +228,20 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
 
-  work =initWork(device, rank_, op_type);
-  // work = make_work_ccl(
-  //     inputs, outputs, fn, xcclComm_t, attr, rank_, op_type);
-  // work->events_.emplace_back(fn);
+  work = initWork(device, rank_, opType);
+  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(std::move(outputs));
 
   c10::xpu::XPUCachingAllocator::recordStream(
       input.storage().data_ptr(), xcclStream);
 
-  auto ccl_stream = ccl::create_stream(at::xpu::XPUStream(xcclStream).queue());
+  auto ccl_stream = ccl::create_stream(xcclStream.queue());
 
   fn(input, output, attr, comm, ccl_stream);
 
   work->xcclEndEvent_->record(xcclStream);
-  c10::MultiStreamGuard streamGuard(xcclStream);
+
+  std::vector<c10::Stream> streams = {xcclStream.unwrap()};
+  c10::MultiStreamGuard streamGuard(streams);
   std::vector<at::Device> devices{device};
   work->future_ = c10::make_intrusive<at::ivalue::Future>(
       c10::ListType::create(c10::TensorType::get()), devices);
@@ -266,13 +286,12 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
           xcclComm_t comm,
           ccl::stream& stream) {
         ccl::event ret_evt;
-        ccl::datatype datatype = getXcclDataType(input.scalar_type());
         ret_evt = ccl::allreduce(
             input.data_ptr(),
            output.data_ptr(),
             (size_t)input.numel(),
             getXcclDataType(input.scalar_type()),
-            xcclOp.at(opts.reduceOp),
+            xcclOps.at(opts.reduceOp),
             comm,
             stream,
             attr);
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 01a5966b81106..b43403f52f31a 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -8,9 +8,10 @@
 #endif // #ifdef USE_C10D_XCCL
 
-#include
-#include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
 #include
@@ -24,14 +26,16 @@
 #include
 #include
+#include
+#include
 #include
 #include
 #include
 #include
-
 namespace c10d {
+using xcclComm_t = ccl::communicator;
+using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
-using namespace torch::xpu::xccl;
 
 class TORCH_XPU_API ProcessGroupXCCL : public Backend {
  public:
@@ -113,6 +118,33 @@ class TORCH_XPU_API ProcessGroupXCCL : public Backend {
     return std::string(XCCL_BACKEND_NAME);
   }
 
+  std::shared_ptr<xcclComm_t> getXCCLComm(
+      const std::string& deviceKey,
+      at::Device& device);
+
+  virtual c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> initWork(
+      at::Device& device,
+      int rank,
+      OpType opType,
+      const std::vector<at::Tensor>& inputs = {},
+      const std::vector<at::Tensor>& outputs = {});
+
+  template <typename Fn>
+  c10::intrusive_ptr<Work> collective(
+      at::Tensor& input,
+      at::Tensor& output,
+      Fn fn,
+      OpType opType);
+
+  template <typename Fn, typename PreProcess, typename PostProcess>
+  c10::intrusive_ptr<Work> collective(
+      at::Tensor& input,
+      at::Tensor& output,
+      Fn fn,
+      PreProcess pre,
+      PostProcess post,
+      OpType opType);
+
   c10::intrusive_ptr<Work> allreduce(
       std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts = AllreduceOptions()) override;
diff --git a/torch/csrc/xpu/xccl.cpp b/torch/csrc/xpu/xccl.cpp
deleted file mode 100644
index 6224b19254dbf..0000000000000
--- a/torch/csrc/xpu/xccl.cpp
+++ /dev/null
@@ -1,348 +0,0 @@
-#include
-
-#include
-#include
-#include
-#include
-
-#include
-
-#include
-#include
-#include
-#include
-
-
-ccl::datatype to_xccl_data_type(c10::ScalarType type) {
-  switch (type) {
-    case at::kFloat:
-      return ccl::datatype::float32;
-    case at::kHalf:
-      return ccl::datatype::float16;
-    case at::kDouble:
-      return ccl::datatype::float64;
-    case at::kLong:
-      return ccl::datatype::int64;
-    case at::kInt:
-      return ccl::datatype::int32;
-    case at::kChar:
-      return ccl::datatype::int8;
-    case at::kByte:
-      return ccl::datatype::uint8;
-    case at::kBool:
-      return ccl::datatype::uint8;
-    case at::kBFloat16:
-      return ccl::datatype::bfloat16;
-    default:
-      TORCH_CHECK(false, "Unconvertible XCCL type ", type);
-  }
-}
-
-ccl::datatype to_xccl_data_type(const at::Tensor& t) {
-  if (!t.is_xpu()) {
-    TORCH_CHECK(
-        false,
-        "XCCL only supports XPU tensors, but got a tensor on ",
-        t.device());
-  }
-  return to_xccl_data_type(t.scalar_type());
-}
-
-ccl::reduction to_xccl_red_op(int var) {
-  return (ccl::reduction)(var);
-}
-
-namespace torch::xpu::xccl {
-
-XCCL_KVS kvs;
-std::mutex kvs_mutex;
-
-XCCL_KVS get_kvs(int rank, c10d::Store& store) {
-  std::lock_guard<std::mutex> lock(kvs_mutex);
-  if (kvs)
-    return kvs;
-  std::string storeKey = "ccl_kvs";
-
-  // Rank 0 broadcast the bootstrap network information to other ranks
-  if (rank == 0) {
-    kvs = ccl::create_main_kvs();
-    ccl::kvs::address_type main_addr = kvs->get_address();
-    auto ccl_kvs_addr =
-        std::vector<uint8_t>(main_addr.begin(), main_addr.end());
-    store.set(storeKey, ccl_kvs_addr);
-  } else {
-    auto ccl_kvs_addr = store.get(storeKey);
-    if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
-      throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
-    }
-    ccl::kvs::address_type main_addr;
-    std::copy_n(
-        ccl_kvs_addr.begin(),
-        ccl::kvs::address_max_size,
-        main_addr.begin());
-    kvs = ccl::create_kvs(main_addr);
-  }
-
-  return kvs;
-}
-
-using namespace at;
-
-namespace detail {
-
-// void xcclCommInitAll(xcclComm_t* newcomm, int nranks, ncclUniqueId commId,
-// int myrank) {
-//   for(int i = 0; i < nranks; i++) {
-//     newcomm[i] = ccl::create_communicator(nranks, i, get_kvs_addr)
-//   }
-//   c10::Stream dpcpp_stream = impl.getStream(devices[0]);
-//   ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
-//   newcomm = ccl::create_communicators(nranks, devs_rank, ctx, )
-// }
-
-// struct XcclCommList {
-//   std::unique_ptr<xcclComm_t[]> comms;
-//   int ndevices;
-//   XcclCommList(const std::vector<int>& devices)
-//       : comms(new xcclComm_t[devices.size()]), ndevices(devices.size()) {
-//     xcclCommInitAll(
-//         to_xccl_comm(comms.get()), devices.size(), devices.data());
-//   }
-//   NcclCommList(NcclCommList&& foo) = default;
-//   ~NcclCommList() {
-//     if (comms) {
-//       for (const auto i : c10::irange(ndevices)) {
-//         comm_destroy(comms[i]);
-//       }
-//     }
-//   }
-//   ArrayRef<xcclComm_t> ref() const {
-//     return ArrayRef<xcclComm_t>(comms.get(), ndevices);
-//   }
-// };
-
-// using device_list = std::vector<int>;
-// // accesses to this object have to be guarded by THC's CudaFreeMutex
-// std::unordered_map<device_list, XcclCommList> _communicators;
-// static std::unordered_map<device_list, XcclCommList>
-//     _communicators;
-
-// ArrayRef<xcclComm_t> get_communicators(TensorList inputs) {
-//   static auto get_device = [](const at::Tensor& t) -> int {
-//     return t.get_device();
-//   };
-//   device_list devices = fmap(inputs, get_device);
-//   auto it = _communicators.find(devices);
-//   if (it == _communicators.end()) {
-//     it = _communicators.emplace(devices, devices).first;
-//   }
-//   return it->second;
-// }
-
-static inline void check_tensor(
-    const at::Tensor& input,
-    const std::optional<at::Tensor>& output,
-    int input_multiplier,
-    int output_multiplier,
-    int64_t ref_numel,
-    ScalarType ref_dtype) {
-  auto check_one = [&](const at::Tensor& tensor) {
-    if (!tensor.is_xpu() || tensor.is_sparse()) {
-      throw std::runtime_error(
-          "input and output elements have to be xpu dense Tensors");
-    }
-
-    if (ref_dtype != tensor.scalar_type()) {
-      throw std::runtime_error(
-          "all inputs and outputs must be of the same Tensor dtype");
-    }
-
-    if (!tensor.is_contiguous()) {
-      throw std::runtime_error("all inputs and outputs have to be contiguous");
-    }
-  };
-
-  check_one(input);
-
-  // all inputs must be same size
-  if (input.numel() != ref_numel) {
-    throw std::runtime_error(
-        "all inputs must have the same number of elements");
-  }
-
-  if (output) {
-    check_one(*output);
-
-    // inputs and outputs must be on same device respectively
-    if (input.get_device() != output->get_device()) {
-      throw std::runtime_error("input and output must be on the same device");
-    }
-
-    if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
-      throw std::runtime_error(
-          "output must be of size input_size * size_multiplier");
-    }
-  }
-}
-
-// void check_inputs(
-//     TensorList inputs,
-//     TensorList outputs,
-//     int input_multiplier,
-//     int output_multiplier) {
-//   // len(inputs) == len(outputs)
-//   size_t len = inputs.size();
-
-//   if (len <= 0) {
-//     throw std::runtime_error("input sequence can't be empty");
-//   }
-
-//   if (len != outputs.size()) {
-//     std::stringstream err;
-//     err << "inputs and outputs sequences have to be of the same length, but got input of length "
-//         << len << " and output of length " << outputs.size();
-//     throw std::runtime_error(err.str());
-//   }
-
-//   device_set devices;
-//   int64_t numel = inputs[0].numel();
-//   auto dtype = inputs[0].scalar_type();
-
-//   for (const auto i : c10::irange(len)) {
-//     auto input = inputs[i];
-//     auto output = outputs[i];
-
-//     check_tensor(
-//         input, output, input_multiplier, output_multiplier, numel, dtype);
-
-//     auto input_device = input.get_device();
-//     // inputs must be on unique devices
-//     if (devices.test(input_device)) {
-//       throw std::runtime_error("inputs must be on unique devices");
-//     }
-//     devices.set(input_device);
-//   }
-// }
-
-// void check_inputs(
-//     TensorList inputs,
-//     const at::Tensor& output,
-//     int root,
-//     int input_multiplier,
-//     int output_multiplier) {
-//   auto len = inputs.size();
-
-//   if (len <= 0) {
-//     throw std::runtime_error("input sequence can't be empty");
-//   }
-
-//   device_set devices;
-//   int64_t numel = inputs[0].numel();
-//   auto dtype = inputs[0].scalar_type();
-
-//   for (const auto i : c10::irange(len)) {
-//     auto input = inputs[i];
-
-//     check_tensor(
-//         input,
-//         i == static_cast<std::remove_cv_t<decltype(root)>>(root)
-//             ? std::optional<at::Tensor>{output}
-//             : std::nullopt,
-//         input_multiplier,
-//         output_multiplier,
-//         numel,
-//         dtype);
-
-//     auto input_device = input.get_device();
-//     // inputs must be on unique devices
-//     if (devices.test(input_device)) {
-//       throw std::runtime_error("inputs must be on unique devices");
-//     }
-//     devices.set(input_device);
-//   }
-// }
-
-} // namespace detail
-
-// std::uint64_t version() {
-// #if defined(NCCL_MAJOR)
-//   constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
-//       (((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
-//   return ver;
-// #elif defined(USE_NCCL)
-//   // return major version "1"
-//   return ((uint64_t)1) << 32;
-// #else
-//   return 0;
-// #endif
-// }
-
-// ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank)
-// {
-// #ifdef USE_XCCL
-//   using namespace torch::xpu::xccl::detail;
-//   xcclComm_t comm;
-//   ncclUniqueId id = comm_id;
-//   NCCL_CHECK(ncclCommInitRank(
-//       to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
-//   return comm;
-// #else
-//   return nullptr;
-// #endif
-// }
-
-// namespace {
-
-// ret_evt = torch::xpu::xccl::all_reduce(
-//     input,
-//     output,
-//     datatype,
-//     xcclOp.at(opts.reduceOp),
-//     comm,
-//     attr,
-//     stream,
-//     root);
-
-// void all_reduce(
-//     at::Tensor& input,
-//     at::Tensor& output,
-//     ccl::datatype datatype,
-//     ccl::reduction op,
-//     const stream_list& streams,
-//     const comm_list& user_comms) {
-// #ifdef USE_XCCL
-//   using namespace torch::cuda::nccl::detail;
-//   check_inputs(inputs, outputs, 1, 1);
-//   const auto len = inputs.size();
-
-//   auto data_type = to_nccl_data_type(inputs[0]);
-
-//   const auto count = inputs[0].numel();
-//   auto comms_ref = user_comms.empty() ? get_communicators(inputs)
-//                                       : ArrayRef<xcclComm_t>(user_comms);
-
-//   AutoNcclGroup nccl_group_guard;
-//   at::cuda::OptionalCUDAGuard device_guard;
-//   for (const auto i : c10::irange(len)) {
-//     auto device = inputs[i].device().index();
-//     device_guard.set_index(device);
-//     // Default to the current stream
-//     const auto stream = (streams.empty() || !streams[i])
-//         ? at::cuda::getCurrentCUDAStream(device).stream()
-//         : streams[i]->stream();
-
-//     ncclComm_t comm = comms_ref[i];
-//     NCCL_CHECK(ncclAllReduce(
-//         inputs[i].data_ptr(),
-//         outputs[i].data_ptr(),
-//         count,
-//         data_type,
-//         to_nccl_red_op(op),
-//         to_nccl_comm(comm),
-//         stream));
-//   }
-// #else
-//   AT_ERROR("PyTorch built without NCCL support");
-// #endif
-// }
-
-} // namespace torch::xpu::xccl
diff --git a/torch/csrc/xpu/xccl.h b/torch/csrc/xpu/xccl.h
deleted file mode 100644
index c7a67975bb286..0000000000000
--- a/torch/csrc/xpu/xccl.h
+++ /dev/null
@@ -1,77 +0,0 @@
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace torch::xpu::xccl {
-
-using xcclComm_t = ccl::communicator;
-
-using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
-
-extern XCCL_KVS kvs;
-
-XCCL_KVS get_kvs(int rank, c10d::Store& store);
-
-enum class xcclRedOp { Sum = 0, Prod = 1, Max = 2, Min = 3 };
-
-enum class xcclDataType {
-  Int8 = 0,
-  Char = 0,
-  Uint8 = 1,
-  Int32 = 2,
-  Int = 2,
-  Uint32 = 3,
-  Int64 = 4,
-  Uint64 = 5,
-  Float16 = 6,
-  Half = 6,
-  Float32 = 7,
-  Float = 7,
-  Float64 = 8,
-  Double = 8,
-  Bfloat16 = 9,
-  NumTypes = 10
-};
-
-namespace detail {
-
-at::ArrayRef<xcclComm_t> get_communicators(at::TensorList inputs);
-void check_inputs(
-    at::TensorList inputs,
-    at::TensorList outputs,
-    int input_multiplier,
-    int output_multiplier);
-void check_inputs(
-    at::TensorList inputs,
-    const at::Tensor& output,
-    int root,
-    int input_multiplier,
-    int output_multiplier);
-
-} // namespace detail
-
-// using comm_list = std::vector<xcclComm_t>;
-// using stream_list = std::vector<std::optional<at::xpu::XPUStream>>;
-
-std::uint64_t version();
-const char* version_suffix();
-
-bool is_available(at::TensorList tensors);
-
-// comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank);
-// void comm_destroy(xcclComm_t comm);
-
-// void all_reduce(
-//     const std::vector<at::Tensor>& inputs,
-//     std::vector<at::Tensor>& outputs,
-//     int32_t op = static_cast<int32_t>(xcclRedOp::Sum),
-//     const stream_list& streams = {},
-//     const comm_list& user_comms = {});
-} // namespace torch::xpu::xccl
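
---

For reviewers, a minimal smoke-test sketch of the surface this patch leaves in place: it constructs the backend through `createProcessGroupXCCL()` and runs one allreduce, which exercises `getXCCLComm()` and the store-based `get_kvs()` rendezvous added above. This is not part of the patch; the TCPStore endpoint, the environment variables, and the tensor shape are illustrative assumptions, and it presumes a build with `USE_C10D_XCCL` plus one XPU device per rank.

```cpp
// Hypothetical smoke test for the XCCL backend (not part of this patch).
// Assumes RANK and WORLD_SIZE are set and port 29500 is free on localhost.
#include <cstdlib>
#include <vector>
#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

int main() {
  const int rank = std::atoi(std::getenv("RANK"));
  const int size = std::atoi(std::getenv("WORLD_SIZE"));

  // Rendezvous store; rank 0 hosts it. get_kvs() piggybacks on this store
  // to exchange the oneCCL KVS address under the "ccl_kvs" key.
  c10d::TCPStoreOptions opts;
  opts.port = 29500;
  opts.isServer = (rank == 0);
  opts.numWorkers = size;
  auto store = c10::make_intrusive<c10d::TCPStore>("127.0.0.1", opts);

  auto pg = c10d::ProcessGroupXCCL::createProcessGroupXCCL(store, rank, size);

  // One allreduce over a single XPU tensor; wait() blocks the current
  // stream on the XCCL end event recorded in collective().
  std::vector<at::Tensor> tensors{
      at::ones({1024}, at::TensorOptions().device(at::kXPU, 0))};
  pg->allreduce(tensors)->wait();
  return 0;
}
```

Driving one allreduce end to end covers both new code paths at once: the lazy KVS bootstrap on first communicator creation and the stream/event bookkeeping in `collective()`.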