diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake
index 29571065c..1881cf3aa 100644
--- a/cmake/Modules/FindXCCL.cmake
+++ b/cmake/Modules/FindXCCL.cmake
@@ -6,7 +6,12 @@ include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
 
-set(XCCL_ROOT $ENV{CCL_ROOT})
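+# Lookup order: prefer the default oneAPI install location, then fall back to
+# the oneCCL from the current environment (CCL_ROOT).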
+set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest")
+if (NOT EXISTS "${XCCL_ROOT}")
+  message(STATUS "Default oneCCL not found, using oneCCL from the current environment")
+  set(XCCL_ROOT $ENV{CCL_ROOT})
+endif()
+
 string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound)
 if(nocclfound)
   set(XCCL_FOUND False)
diff --git a/src/xccl/CMakeLists.txt b/src/xccl/CMakeLists.txt
index 242cf7e26..7e16ea8ff 100644
--- a/src/xccl/CMakeLists.txt
+++ b/src/xccl/CMakeLists.txt
@@ -7,6 +7,7 @@ list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp})
 
 set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE)
 
+# Copy the headers to the build directory so that the PyTorch registration file can locate them.
 foreach(HEADER ${xccl_h})
   file(COPY ${HEADER} DESTINATION "${CMAKE_BINARY_DIR}/torch/csrc/distributed/c10d")
 endforeach()
diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 71e26fbe0..4a82e3cbc 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -24,114 +24,24 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kDouble, ccl::datatype::float64},
     {at::kBFloat16, ccl::datatype::bfloat16},
     {at::kBool, ccl::datatype::uint8},
-    // use for allgather
-    {at::kFloat8_e5m2, ccl::datatype::uint8},
-    {at::kFloat8_e4m3fn, ccl::datatype::uint8},
-    {at::kFloat8_e4m3fnuz, ccl::datatype::uint8},
-    {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
 };
 
-bool computeLengthsAndCheckAndGetFlat(
-    const std::vector<at::Tensor>& tensors,
-    std::vector<size_t>& lengths,
-    at::Tensor& flatTensor,
-    int64_t& flatLength) {
-  int64_t groupSize = tensors.size();
-  auto firstTensor = tensors[0];
-  int64_t totalSize = 0;
-  bool isFlat = true;
-
-  auto storage = firstTensor.storage();
-  int64_t firstStorageOffset = firstTensor.storage_offset();
-
-  for (int i = 0; i < groupSize; i++) {
-    auto& curTensor = tensors[i];
-    int64_t length = curTensor.numel();
-    lengths[i] = length;
-    totalSize += length;
-
-    if (isFlat &&
-        (!storage.is_alias_of(curTensor.storage()) ||
-         curTensor.storage_offset() !=
-             firstStorageOffset + totalSize - length)) {
-      isFlat = false;
-    }
-  }
-
-  flatLength = totalSize;
-
-  if (isFlat) {
-    flatTensor = firstTensor;
-  } else {
-    flatTensor = at::empty({totalSize}, firstTensor.options());
-  }
-
-  return isFlat;
-}
-
-bool check_same_size(const std::vector<at::Tensor>& input_tensors) {
-  for (const auto& input_tensor : input_tensors) {
-    if (!input_tensors[0].is_same_size(input_tensor)) {
-      return false;
-    }
-  }
-  return true;
-}
-
-void check_xpu_single_tensor(
-    const at::Tensor& tensor,
-    const bool p2p = false // whether operation is a P2P operation
-) {
-  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
-    C10_THROW_ERROR(
-        ValueError, "Tensors must be XPU and dense and non-complex");
-
-    // Skip the following requirements for P2P operations
-    if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
-      if (p2p) {
-        TORCH_WARN_ONCE(
-            "Detected non-contiguous tensor in P2P operations. It is user "
-            "responsibility to guarantee that source and destination tensors have "
-            "the same contiguity format.");
-      } else {
-        C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
-      }
-    }
-  }
-}
+void checkXPUTensor(at::Tensor& tensor) {
+  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
+    C10_THROW_ERROR(
+        ValueError, "Tensors must be XPU and dense and non-complex");
+  }
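+  // Contiguity is required unconditionally now; the old code only warned for P2P ops.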
+  if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
+    C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+  }
+}
 
-int64_t check_xpu_tensors_same_device(const std::vector<at::Tensor>& tensors) {
-  TORCH_CHECK_WITH(
-      ValueError, tensors.size() != 0, "Tensor list must be nonempty");
-
-  const auto& first = tensors.front();
-
-  int64_t total_numel = 0;
-  for (const auto& t : tensors) {
-    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
-      C10_THROW_ERROR(
-          ValueError, "Tensors must be XPU and dense and non-complex");
-    }
-    if (t.scalar_type() != first.scalar_type()) {
-      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
-    }
-    TORCH_CHECK_WITH(
-        ValueError,
-        t.get_device() == tensors[0].get_device(),
-        "Expected list of tensors on the same device");
-    total_numel += t.numel();
-  }
-
-  return total_numel;
-}
-
 ccl::datatype getXcclDataType(
     at::ScalarType type,
     bool is_reduction_op = false) {
-  if (is_reduction_op)
-    TORCH_CHECK(
-        !isFloat8Type(type),
-        "Float8 dtypes are not currenlty supported for XCCL reductions");
+  TORCH_CHECK(
+      !(isFloat8Type(type) && is_reduction_op),
+      "Float8 dtypes are not currently supported for XCCL reductions");
   auto it = xcclDatatypes.find(type);
   TORCH_CHECK_WITH(
       TypeError,
@@ -147,10 +57,6 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
       // Map sum to max for bool tensors to avoid overflow issues with sum.
       return ccl::reduction::max;
     }
-    // WA due to oneCCL not support AVG
-    if (reduceOp == ReduceOp::AVG) {
-      return ccl::reduction::sum;
-    }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
     C10_THROW_ERROR(
@@ -166,11 +72,9 @@ void syncStream(
   xcclEvent.record(at::xpu::getCurrentXPUStream(device.index()));
   xcclEvent.block(xcclStream);
 }
-
 } // namespace
 
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
-thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;
 
 ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     at::Device& device,
@@ -225,10 +129,6 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
     }
   }
-  if (barrierTensor_.defined()) {
-    auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
-    currentStream.synchronize();
-  }
 }
 
 bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
@@ -236,9 +136,6 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
   return true;
 }
 
-constexpr const char* MULTI_DEVICE_ERROR_MSG =
-    "Expecting one tensor only but got multiple";
-
 ProcessGroupXCCL::ProcessGroupXCCL(
     const c10::intrusive_ptr<Store>& store,
     int rank,
@@ -250,12 +147,6 @@ ProcessGroupXCCL::ProcessGroupXCCL(
 
 ProcessGroupXCCL::~ProcessGroupXCCL() = default;
 
-void ProcessGroupXCCL::setSequenceNumberForGroup() {}
-
-uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() {
-  return seqCollective_;
-}
-
 c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
     at::Device& device,
     int rank,
@@ -275,19 +166,12 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
 
 std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     const std::string& deviceKey,
-    at::Device& device,
-    OpType opType,
-    int p2pRank,
-    bool isSendRecvSelf) {
-  if (deviceKey.empty()) {
-    C10_THROW_ERROR(
-        DistBackendError,
-        "Not able to create/get the XCCL Communicator since "
-        "the devices are empty ");
-  }
-
-  usedDeviceIdxs_.insert(device.index());
-
+    at::Device& device) {
+  TORCH_CHECK_WITH(
+      DistBackendError,
+      !deviceKey.empty(),
+      "Not able to create/get "
+      "XCCL Communicator since the devices are empty ");
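+  // Reuse the communicator cached for this device index, if any.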
   {
     std::lock_guard<std::mutex> lock(mutex_);
     if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
@@ -295,24 +179,9 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     }
   }
 
-  std::shared_ptr<xcclComm_t> XCCLComm;
-
-  bool batchP2P = xcclActiveGroupCounter_ > 0;
-  bool singleP2POp = isP2POp(opType, batchP2P);
-
-  at::xpu::OptionalXPUGuard gpuGuard(device);
-  int numRanks, rank;
-  if (!singleP2POp) {
-    numRanks = getSize();
-    rank = getRank();
-  } else if (isSendRecvSelf) {
-    numRanks = 1;
-    rank = 0;
-  } else {
-    numRanks = 2;
-    rank = p2pRank;
-  }
+  int numRanks = getSize();
+  int rank = getRank();
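+  // Rank and size always refer to the full group now that the 1- and 2-rank
+  // P2P communicator setups are gone.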
 
   c10::impl::VirtualGuardImpl impl(device.type());
   c10::Stream stream = impl.getStream(impl.getDevice());
   sycl::queue& q = c10::xpu::XPUStream(stream).queue();
 
   auto ctx = ccl::create_context(q.get_context());
   ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
   devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
 
-  auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank);
+  auto xccl_kvs = get_kvs(rank_, *store_);
   auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
-  XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
-
-  RECORD_PARAM_COMMS(
-      0, // seq
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      rank, // rank
-      "init", // collective name
-      0, // inNelems
-      0, // outNelems
-      at::kByte, // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      size_); // worldSize
+  std::shared_ptr<xcclComm_t> XCCLComm =
+      std::make_shared<xcclComm_t>(std::move(comms[0]));
 
   std::lock_guard<std::mutex> lock(mutex_);
   devXCCLCommMap_.emplace(deviceKey, XCCLComm);
@@ -349,64 +205,6 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   return XCCLComm;
 }
 
-void ProcessGroupXCCL::groupStart() {
-  ccl::group_start();
-  ++xcclActiveGroupCounter_;
-}
-
-void ProcessGroupXCCL::groupEnd() {
-  ccl::group_end();
-  --xcclActiveGroupCounter_;
-}
-
-// TODO: wait p2p enable
-static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04;
-void ProcessGroupXCCL::startCoalescing() {
-  if (coalescing_state_ & CoalP2P) {
-    seqP2P_++;
-  } else {
-    seqCollective_++;
-  }
-  coalescedDevice_.set_index(-1);
-  coalescedComm_ = nullptr;
-  coalescing_state_ |= CoalActive;
-  groupStart();
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing(OpType optype) {
-  if (coalescedComm_ == nullptr) {
-    // There is no actual work being coalesced, return here
-    groupEnd();
-    coalescing_state_ = 0;
-    return nullptr;
-  }
-  TORCH_CHECK(
-      coalescedDevice_.index() >= 0,
-      "Somthing went wrong. Did you call end_coalescing before start_coalescing?");
-
-  auto comm = coalescedComm_;
-  auto device = coalescedDevice_;
-
-  const auto key = std::to_string(device.index());
-  auto stream = xcclStreamsMap_.at(key);
-
-  auto work = initWork(device, rank_, optype);
-  work->blockingWait_ = blockingWait_;
-
-  groupEnd();
-
-  work->xcclEndEvent_->record(stream);
-
-  coalescing_state_ = 0;
-  coalescedComm_ = nullptr;
-  return work;
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::endCoalescing() {
-  // Default OpType to COALESCED if not specified
-  return endCoalescing(OpType::COALESCED);
-}
-
 template <typename Fn, typename PreProcess, typename PostProcess>
 c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     std::vector<at::Tensor>& inputs,
@@ -417,49 +215,28 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
     OpType opType,
     const char* profilingTitle) {
   seqCollective_++;
+
   auto device = inputs[0].device();
   const auto key = std::to_string(device.index());
-  auto comm = getXCCLComm(key, device, opType);
-
-  if (coalescing_state_ & CoalActive) {
-    coalescing_state_ |= CoalColl;
-    if (coalescedDevice_.index() < 0) {
-      coalescedDevice_ = device;
-    } else {
-      TORCH_CHECK(
-          coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
-    }
-    if (coalescedComm_ == nullptr) {
-      coalescedComm_ = comm;
-    } else {
-      TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
-    }
-  }
+  auto comm = getXCCLComm(key, device);
 
   auto stream = xcclStreamsMap_.at(key);
   syncStream(device, xcclEventsMap_[key], stream);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
-  work = initWork(device, rank_, opType);
-
+  work = initWork(device, rank_, opType, profilingTitle);
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
 
   at::xpu::OptionalXPUGuard gpuGuard(device);
 
-  pre(stream, work);
-
   for (const auto i : c10::irange(inputs.size())) {
     c10::xpu::XPUCachingAllocator::recordStream(
         inputs[i].storage().data_ptr(), stream);
     fn(inputs[i], outputs[i], *comm, stream);
   }
 
-  post(stream, work);
-  if (!coalescing_state_) {
-    work->xcclEndEvent_->record(stream);
-  }
-
+  work->xcclEndEvent_->record(stream);
   std::vector<c10::Stream> streams = {stream.unwrap()};
   c10::MultiStreamGuard streamGuard(streams);
   std::vector<at::Device> devices{device};
   work->future_ = c10::make_intrusive<at::ivalue::Future>(
       c10::ListType::create(c10::TensorType::get()), devices);
@@ -471,468 +248,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   return work;
 }
 
-template <typename Fn>
-c10::intrusive_ptr<Work> ProcessGroupXCCL::pointToPoint(
-    at::Tensor& tensor,
-    Fn fn,
-    int peer,
-    OpType opType,
-    const char* profilingTitle) {
-  auto device = tensor.device();
-  std::string key;
-  int p2pRank = 0, p2pTargetRank = 0;
-  bool isSendRecvSelf = false;
-
-  bool batchP2P = xcclActiveGroupCounter_ > 0;
-  if (batchP2P) {
-    key = std::to_string(device.index());
-    p2pRank = rank_;
-    p2pTargetRank = peer;
-  } else {
-    int lowRank = rank_ < peer ? rank_ : peer;
-    int highRank = rank_ < peer ? peer : rank_;
-    key = std::to_string(lowRank) + ":" + std::to_string(highRank);
-    p2pRank = rank_ <= peer ? 0 : 1;
-    isSendRecvSelf = rank_ == peer;
-    p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank;
-    if (!coalescing_state_) {
-      seqP2P_++;
-    }
-  }
-
-  auto comm = getXCCLComm(key, device, opType, p2pRank, isSendRecvSelf);
-
-  if (coalescing_state_ & CoalActive) {
-    coalescing_state_ |= CoalP2P;
-    if (coalescedDevice_.index() < 0) {
-      coalescedDevice_ = device;
-    } else {
-      TORCH_CHECK(
-          coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG);
-    }
-    if (coalescedComm_ == nullptr) {
-      coalescedComm_ = comm;
-    } else {
-      TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG);
-    }
-  }
-
-  auto stream = xcclStreamsMap_.at(key);
-  syncStream(device, xcclEventsMap_[key], stream);
-
-  if (!coalescing_state_) {
-    c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
-    work = initWork(device, rank_, opType);
-    work->outputs_ = std::make_shared<std::vector<at::Tensor>>();
-    work->outputs_->push_back(tensor);
-
-    at::xpu::OptionalXPUGuard gpuGuard(device);
-
-    c10::xpu::XPUCachingAllocator::recordStream(
-        tensor.storage().data_ptr(), stream);
-
-    fn(tensor, *comm, stream, p2pTargetRank);
-
-    work->xcclEndEvent_->record(stream);
-    work->blockingWait_ = blockingWait_;
-    std::vector<c10::Stream> streams = {stream.unwrap()};
-    c10::MultiStreamGuard streamGuard(streams);
-    std::vector<at::Device> devices{device};
-    work->future_ = c10::make_intrusive<at::ivalue::Future>(
-        c10::ListType::create(c10::TensorType::get()), devices);
-    work->future_->markCompleted(at::IValue(*work->outputs_));
-    return work;
-  } else {
-    at::xpu::OptionalXPUGuard gpuGuard(device);
-
-    c10::xpu::XPUCachingAllocator::recordStream(
-        tensor.storage().data_ptr(), stream);
-
-    fn(tensor, *comm, stream, p2pTargetRank);
-
-    return nullptr;
-  }
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::send(
-    std::vector<at::Tensor>& tensors,
-    int dstRank,
-    int /* unused */) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor, true);
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      dstRank, // dst rank
-      "send", // collective name
-      tensor.numel(), // inNelems
-      tensor.numel(), // outNelems
-      tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  auto ret = pointToPoint(
-      tensor,
-      [&](at::Tensor& input,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream,
-          int dst) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::send(
-            input.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            dst,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      dstRank,
-      OpType::SEND,
-      c10::str("xccl:send ", rank_, "->", dstRank).c_str());
-  return ret;
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::recv(
-    std::vector<at::Tensor>& tensors,
-    int srcRank,
-    int /* unused */) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor, true);
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      srcRank, // src rank
-      "recv", // collective name
-      tensor.numel(), // inNelems
-      tensor.numel(), // outNelems
-      tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  auto ret = pointToPoint(
-      tensor,
-      [&](at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream,
-          int src) {
-        auto xcclDataType = getXcclDataType(output.scalar_type());
-        ccl::recv(
-            output.data_ptr(),
-            (size_t)output.numel(),
-            xcclDataType,
-            src,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      srcRank,
-      OpType::RECV,
-      c10::str("xccl:recv ", rank_, "<-", srcRank).c_str());
-  return ret;
-}
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::gather(
-    std::vector<std::vector<at::Tensor>>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const GatherOptions& opts) {
-  static auto invalidArgument = [](const std::string& msg) {
-    C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::gather: " + msg);
-  };
-
-  assertRootRank(invalidArgument, opts.rootRank, size_);
-
-  TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto inputTensor = inputTensors.back();
-
-  std::vector<at::Tensor> outputs;
-
-  if (getRank() == opts.rootRank) {
-    if (outputTensors.size() != 1) {
-      std::stringstream ss;
-      ss << "requires a single-element output list containing a list with "
-         << getSize() << " tensors.";
-      invalidArgument(ss.str());
-    } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
-      std::stringstream ss;
-      ss << "Incorrect output list size " << outputTensors[0].size()
-         << ". Output list size should be " << getSize()
-         << ", same as size of the process group.";
-      invalidArgument(ss.str());
-    }
-
-    const auto& options = inputTensor.options();
-    const auto& sizes = inputTensor.sizes();
-    assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes);
-    outputs = outputTensors[0];
-  } else {
-    // if not in the root rank, initialize outputs as empty list
-    if (outputTensors.size() != 0) {
-      invalidArgument("requires empty output on non-root");
-    }
-    outputs = {};
-    // append a empty tensor to the list, we don't use it but the
-    // `collective` template function requires it to invoke its function
-    outputs.emplace_back();
-  }
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensors, // inputTensors
-      outputTensors, // outputTensors
-      opts.rootRank, // root rank
-      "gather", // collective name
-      inputTensor.numel(), // inNelems
-      inputTensor.numel() * this->getSize(), // outNelems
-      inputTensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSize
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  auto inputs = std::vector<at::Tensor>{inputTensor};
-  return collective(
-      inputs,
-      outputs, // just to fit the collective interface
-      [&](at::Tensor& /* unused */,
-          at::Tensor& /* unused */,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        const auto root = opts.rootRank;
-        if (getRank() == root) {
-          for (auto output : outputs) {
-            c10::xpu::XPUCachingAllocator::recordStream(
-                output.storage().data_ptr(), stream);
-          }
-        }
-        {
-          auto xcclDataType = getXcclDataType(inputTensor.scalar_type());
-          if (rank_ == root) {
-            for (const auto r : c10::irange(size_)) {
-              if (r != root) {
-                // do receive
-                ccl::recv(
-                    outputs[r].data_ptr(),
-                    (size_t)inputTensor.numel(),
-                    xcclDataType,
-                    r,
-                    comm,
-                    ccl::create_stream(stream.queue()));
-              } else {
-                // on its own rank, simply copy from the input
-                outputs[r].copy_(inputTensor);
-              }
-            }
-          } else {
-            // do send
-            ccl::send(
-                inputTensor.data_ptr(),
-                (size_t)inputTensor.numel(),
-                xcclDataType,
-                root,
-                comm,
-                ccl::create_stream(stream.queue()));
-          }
-          return;
-        }
-      },
-      OpType::GATHER);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::scatter(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<std::vector<at::Tensor>>& inputTensors,
-    const ScatterOptions& opts) {
-  static auto invalidArgument = [](const std::string& msg) {
-    C10_THROW_ERROR(ValueError, "ProcessGroupXCCL::scatter: " + msg);
-  };
-
-  assertRootRank(invalidArgument, opts.rootRank, size_);
-
-  TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  auto outputTensor = outputTensors.back();
-
-  std::vector<at::Tensor> inputs;
-
-  if (getRank() == opts.rootRank) {
-    if (inputTensors.size() != 1) {
-      std::stringstream ss;
-      ss << "requires a single-element input list containing a list with "
-         << getSize() << " tensors.";
-      invalidArgument(ss.str());
-    } else if (inputTensors[0].size() != static_cast<size_t>(getSize())) {
-      std::stringstream ss;
-      ss << "Incorrect input list size " << inputTensors[0].size()
-         << ". Input list size should be " << getSize()
-         << ", same as size of the process group.";
-      invalidArgument(ss.str());
-    }
-
-    const auto& options = outputTensor.options();
-    const auto& sizes = outputTensor.sizes();
-    assertTypeAndSizesMatch(invalidArgument, inputTensors[0], options, sizes);
-    inputs = inputTensors[0];
-  } else {
-    // if not in the root rank, initialize inputTensors as empty place holder
-    // with an empty list
-    if (inputTensors.size() != 0) {
-      invalidArgument("requires empty input on non-root");
-    }
-    inputs = {};
-    // append a empty tensor to the list, we don't use it but the
-    // `collective` template function requires it to invoke its function
-    inputs.emplace_back();
-  }
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensors, // inputTensors
-      outputTensors, // outputTensors
-      opts.rootRank, // root rank
-      "scatter", // collective name
-      outputTensor.numel() * this->getSize(), // inNelems
-      outputTensor.numel(), // outNelems
-      outputTensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSize
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  const auto root = opts.rootRank;
-
-  auto outputs = std::vector<at::Tensor>{outputTensor};
-  return collective(
-      outputs,
-      inputs, // just to fit the collective interface
-      [&](at::Tensor& /* unused */,
-          at::Tensor& /* unused */,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        if (getRank() == root) {
-          for (auto input : inputs) {
-            c10::xpu::XPUCachingAllocator::recordStream(
-                input.storage().data_ptr(), stream);
-          }
-        }
-        {
-          if (rank_ == root) {
-            for (const auto r : c10::irange(size_)) {
-              if (r != root) {
-                // do send
-                size_t send_count = inputs[r].numel();
-                auto send_type = getXcclDataType(inputs[r].scalar_type());
-                ccl::send(
-                    inputs[r].data_ptr(),
-                    send_count,
-                    send_type,
-                    r,
-                    comm,
-                    ccl::create_stream(stream.queue()));
-              } else {
-                // on its own rank, simply copy from the input
-                outputTensor.copy_(inputs[r]);
-              }
-            }
-          } else {
-            // do receive
-            size_t recv_count = outputTensor.numel();
-            auto recv_type = getXcclDataType(outputTensor.scalar_type());
-            ccl::recv(
-                outputTensor.data_ptr(),
-                recv_count,
-                recv_type,
-                root,
-                comm,
-                ccl::create_stream(stream.queue()));
-          }
-
-          return;
-        }
-      },
-      OpType::SCATTER);
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
-    at::Tensor& tensor,
-    const AllreduceOptions& opts) {
-  return collective(
-      tensor,
-      tensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        auto ccl_stream = ccl::create_stream(stream.queue());
-        ccl::allreduce(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::ALLREDUCE,
-      "xccl:all_reduce");
-}
-
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
     std::vector<at::Tensor>& tensors,
     const AllreduceOptions& opts) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
+  TORCH_CHECK(
+      tensors.size() == 1, "Expecting one tensor only but got multiple");
   auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor);
-
-  // @lint-ignore CLANGTIDY
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      rank_, // rank
-      "allreduce", // collective name
-      tensor.numel(), // inNelems
-      tensor.numel(), // outNelems
-      tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      size_); // worldSize
+  checkXPUTensor(tensor);
 
   return collective(
       tensor,
@@ -943,6 +265,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
         at::xpu::XPUStream& stream) {
         auto xcclDataType = getXcclDataType(input.scalar_type(), true);
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        auto ccl_stream = ccl::create_stream(stream.queue());
         ccl::allreduce(
             input.data_ptr(),
             output.data_ptr(),
@@ -950,845 +273,13 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
             xcclDataType,
             xcclReduceOp,
             comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
+            ccl_stream);
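+        // Note: oneCCL has no native AVG, and the old sum-then-divide
+        // workaround for ReduceOp::AVG is removed along with it.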
         return;
       },
       OpType::ALLREDUCE,
       "xccl:all_reduce");
 }
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_coalesced(
-    std::vector<at::Tensor>& tensors,
-    const AllreduceCoalescedOptions& opts) {
-  auto total_numel = check_xpu_tensors_same_device(tensors);
-
-  // @lint-ignore CLANGTIDY
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      rank_, // rank
-      "allreduce_coalesced", // collective name
-      total_numel, // inNelems
-      total_numel, // outNelems
-      tensors[0].scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  return collectiveCoalesced(
-      tensors,
-      tensors,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::allreduce(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::COALESCED,
-      "xccl:allreduce_coalesced");
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::broadcast(
-    std::vector<at::Tensor>& tensors,
-    const BroadcastOptions& opts) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor);
-
-  // @lint-ignore CLANGTIDY
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      opts.rootRank, // root rank
-      "broadcast", // collective name
-      tensor.numel(), // inNelems
-      tensor.numel(), // outNelems
-      tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  const auto root = opts.rootRank + opts.rootTensor;
-
-  return collective(
-      tensor,
-      tensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::broadcast(
-            input.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            root,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      OpType::BROADCAST,
-      "nccl:broadcast");
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::_broadcast_oop(
-    at::Tensor& outputTensor,
-    at::Tensor& inputTensor,
-    const BroadcastOptions& opts) {
-  if (outputTensor.numel() != inputTensor.numel()) {
-    C10_THROW_ERROR(
-        ValueError,
-        "Tensor input and output of _broadcast_oop must have the same number of elements ");
-  }
-  const auto root = opts.rootRank + opts.rootTensor;
-  return collective(
-      inputTensor,
-      outputTensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::broadcast(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            root,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      OpType::BROADCAST,
-      "xccl:_broadcast_oop");
-}
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce(
-    std::vector<at::Tensor>& tensors,
-    const ReduceOptions& opts) {
-  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor);
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      tensors, // inputTensors
-      tensors, // outputTensors
-      opts.rootRank, // root rank
-      "reduce", // collective name
-      tensor.numel(), // inNelems
-      tensor.numel(), // outNelems
-      tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  return collective(
-      tensor,
-      tensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::reduce(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            root,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::REDUCE,
-      "xccl:reduce");
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_oop(
-    at::Tensor& outputTensor,
-    at::Tensor& inputTensor,
-    const ReduceOptions& opts) {
-  TORCH_CHECK_WITH(
-      ValueError,
-      outputTensor.numel() == inputTensor.numel(),
-      "Tensor input and output of _reduce_oop must have the same number of elements");
-  return collective(
-      inputTensor,
-      outputTensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        const int root = opts.rootRank + opts.rootTensor;
-        const auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::reduce(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            root,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG && getRank() == root) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::REDUCE,
-      "xccl:_reduce_oop");
-}
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather(
-    std::vector<std::vector<at::Tensor>>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const AllgatherOptions& opts) {
-  TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto inputTensor = inputTensors.back();
-  check_xpu_single_tensor(inputTensor);
-  // @lint-ignore CLANGTIDY
-  std::vector<at::Tensor>& outputTensors_ = outputTensors.back();
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensors, // inputTensors
-      outputTensors, // outputTensors
-      rank_, // rank
-      "all_gather", // collective name
-      inputTensor.numel(), // inNelems
-      inputTensor.numel() * // outNelems
-          this->getSize(),
-      inputTensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSize
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  bool same_size = check_same_size(outputTensors_);
-  if (same_size) {
-    // Flatten a vector of tensors into a single, stacked tensor.
-    at::Tensor outputFlattened = newLikeFlat(outputTensors_);
-
-    return collective(
-        inputTensor,
-        outputFlattened,
-        [&](at::Tensor& input,
-            at::Tensor& output,
-            xcclComm_t& comm,
-            at::xpu::XPUStream& stream) {
-          c10::xpu::XPUCachingAllocator::recordStream(
-              output.storage().data_ptr(), stream);
-          auto xcclDataType = getXcclDataType(input.scalar_type());
-          ccl::allgather(
-              input.data_ptr(),
-              output.data_ptr(),
-              (size_t)input.numel(),
-              xcclDataType,
-              comm,
-              ccl::create_stream(stream.queue()));
-          return;
-        },
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {},
-        [&](at::xpu::XPUStream& Stream,
-            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {
-          // Copy the flattened output tensors to the outputs.
-          c10::StreamGuard guard(Stream);
-          for (const auto j : c10::irange(outputTensors_.size())) {
-            c10::xpu::XPUCachingAllocator::recordStream(
-                outputTensors_[j].storage().data_ptr(), Stream);
-            outputTensors_[j].copy_(outputFlattened[j], true);
-          }
-        },
-        OpType::ALLGATHER,
-        "xccl:all_gather");
-  } else {
-    const auto num_reduces = outputTensors_.size();
-    startCoalescing();
-    for (const int i : c10::irange(num_reduces)) {
-      auto& output = outputTensors_[i];
-      auto& input = (i == rank_) ? inputTensor : output;
-      auto broadcastOpts = BroadcastOptions{
-          static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
-      _broadcast_oop(output, input, broadcastOpts);
-    }
-    auto work = endCoalescing(OpType::ALLGATHER);
-    return work;
-  }
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::_allgather_base(
-    at::Tensor& output_tensor,
-    at::Tensor& input_tensor,
-    const AllgatherOptions& opts) {
-  check_xpu_single_tensor(input_tensor);
-  check_xpu_single_tensor(output_tensor);
-
-  TORCH_CHECK_WITH(
-      TypeError,
-      input_tensor.dtype() == output_tensor.dtype(),
-      "output tensor must have the same type as input tensor");
-  TORCH_CHECK_WITH(
-      ValueError,
-      input_tensor.numel() * size_ == output_tensor.numel(),
-      "output tensor size must be equal to world_size times input tensor size");
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      input_tensor, // inputTensors
-      output_tensor, // outputTensors
-      rank_, // rank
-      "_allgather_base", // collective name
-      input_tensor.numel(), // inNelems
-      output_tensor.numel(), // outNelems
-      output_tensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSize
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  return collective(
-      input_tensor,
-      output_tensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        c10::xpu::XPUCachingAllocator::recordStream(
-            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::allgather(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      OpType::_ALLGATHER_BASE,
-      "xccl:_all_gather_base");
-}
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::allgather_into_tensor_coalesced(
-    std::vector<at::Tensor>& outputs,
-    std::vector<at::Tensor>& inputs,
-    const AllgatherOptions& opts) {
-  return collectiveCoalesced(
-      inputs,
-      outputs,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        auto xcclDataType = getXcclDataType(input.scalar_type());
-        ccl::allgather(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)input.numel(),
-            xcclDataType,
-            comm,
-            ccl::create_stream(stream.queue()));
-        return;
-      },
-      OpType::COALESCED,
-      "xccl:all_gather_into_tensor_coalesced");
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<std::vector<at::Tensor>>& inputTensors,
-    const ReduceScatterOptions& opts) {
-  TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
-  auto outputTensor = outputTensors.back();
-  check_xpu_single_tensor(outputTensor);
-  // @lint-ignore CLANGTIDY
-  auto inputTensors_ = inputTensors.back();
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensors, // inputTensors
-      outputTensors, // outputTensors
-      rank_, // rank
-      "reduce_scatter", // collective name
-      outputTensor.numel() * this->getSize(), // inNelems
-      outputTensor.numel(), // outNelems
-      outputTensor.scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  bool same_size = check_same_size(inputTensors_);
-  if (same_size) {
-    // Flatten a vector of tensors into a single, stacked tensor.
-    at::Tensor inputFlattened = newLikeFlat(inputTensors_);
-    return collective(
-        inputFlattened,
-        outputTensor,
-        [&](at::Tensor& input,
-            at::Tensor& output,
-            xcclComm_t& comm,
-            at::xpu::XPUStream& stream) {
-          c10::xpu::XPUCachingAllocator::recordStream(
-              output.storage().data_ptr(), stream);
-          auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-          auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-          ccl::reduce_scatter(
-              input.data_ptr(),
-              output.data_ptr(),
-              (size_t)output.numel(),
-              xcclDataType,
-              xcclReduceOp,
-              comm,
-              ccl::create_stream(stream.queue()));
-          // WA due to oneCCL not support AVG
-          if (opts.reduceOp == ReduceOp::AVG) {
-            auto divisor = getSize();
-            output.div_(divisor);
-          }
-          return;
-        },
-        [&](at::xpu::XPUStream& Stream,
-            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>& work) {
-          // Copy the input tensors to the flattened inputs.
-          c10::StreamGuard guard(Stream);
-          for (const auto j : c10::irange(inputTensors_.size())) {
-            c10::xpu::XPUCachingAllocator::recordStream(
-                inputTensors_[j].storage().data_ptr(), Stream);
-            inputFlattened[j].copy_(inputTensors_[j], true);
-          }
-        },
-        [&](at::xpu::XPUStream&,
-            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        OpType::REDUCE_SCATTER,
-        "xccl:reduce_scatter");
-  } else {
-    const auto num_reduces = inputTensors_.size();
-    startCoalescing();
-    for (const int i : c10::irange(num_reduces)) {
-      auto& input = inputTensors_[i];
-      auto& output = (i == rank_) ? outputTensor : input;
-      auto reduceOpts = ReduceOptions{
-          opts.reduceOp,
-          static_cast<int64_t>(i),
-          static_cast<int64_t>(0),
-          opts.timeout};
-      _reduce_oop(output, input, reduceOpts);
-    }
-    auto work = endCoalescing(OpType::REDUCE_SCATTER);
-    return work;
-  }
-}
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::_reduce_scatter_base(
-    at::Tensor& outputTensor,
-    at::Tensor& inputTensor,
-    const ReduceScatterOptions& opts) {
-  TORCH_CHECK_WITH(
-      TypeError,
-      inputTensor.dtype() == outputTensor.dtype(),
-      "input tensor must be the same type as the output tensor.");
-  TORCH_CHECK_WITH(
-      ValueError,
-      inputTensor.numel() == outputTensor.numel() * size_,
-      "input tensor must be the same size as output size times world size");
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensor, // inputTensor
-      outputTensor, // outputTensor
-      rank_, // rank
-      "_reduce_scatter_base", // collective name
-      inputTensor.numel(), // inNelems
-      outputTensor.numel(), // outNelems
-      outputTensor.scalar_type(), // dtype
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  return collective(
-      inputTensor,
-      outputTensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        c10::xpu::XPUCachingAllocator::recordStream(
-            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::reduce_scatter(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)output.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::_REDUCE_SCATTER_BASE,
-      "xccl:_reduce_scatter_base");
-}
-
-c10::intrusive_ptr<Work> ProcessGroupXCCL::reduce_scatter_tensor_coalesced(
-    std::vector<at::Tensor>& outputs,
-    std::vector<at::Tensor>& inputs,
-    const ReduceScatterOptions& opts) {
-  return collectiveCoalesced(
-      inputs,
-      outputs,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        c10::xpu::XPUCachingAllocator::recordStream(
-            output.storage().data_ptr(), stream);
-        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
-        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::reduce_scatter(
-            input.data_ptr(),
-            output.data_ptr(),
-            (size_t)output.numel(),
-            xcclDataType,
-            xcclReduceOp,
-            comm,
-            ccl::create_stream(stream.queue()));
-        // WA due to oneCCL not support AVG
-        if (opts.reduceOp == ReduceOp::AVG) {
-          auto divisor = getSize();
-          output.div_(divisor);
-        }
-        return;
-      },
-      OpType::COALESCED,
-      "xccl:reduce_scatter_tensor_coalesced");
-}
"); - auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); - - at::Tensor barrierTensor = - at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); - - auto work = allreduce_impl(barrierTensor); - - auto xcclWork = dynamic_cast(work.get()); - TORCH_CHECK(xcclWork); - xcclWork->barrierTensor_ = std::move(barrierTensor); - return work; -} - -c10::intrusive_ptr ProcessGroupXCCL::alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { - check_xpu_single_tensor(outputTensor, true); - check_xpu_single_tensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_all", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - TORCH_CHECK( - outputTensor.numel() == inputTensor.numel() && - outputTensor.scalar_type() == inputTensor.scalar_type(), - "xpu_alltoall_base: tensors are not equal in size or data type"); - TORCH_CHECK( - outputTensor.size(0) % size_ == 0, - "xpu_alltoall_base: tensor's dim 0 does not divide equally across group size"); - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - c10::xpu::XPUCachingAllocator::recordStream( - output.storage().data_ptr(), stream); - auto xcclDataType = getXcclDataType(output.scalar_type()); - ccl::alltoall( - input.data_ptr(), - output.data_ptr(), - (size_t)output.numel() / comm.size(), - xcclDataType, - comm, - ccl::create_stream(stream.queue())); - return; - }, - OpType::ALLTOALL_BASE, - "xccl:all_to_all"); - } else { - c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); - c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); - - RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - inputTensor, // inputTensor - outputTensor, // outputTensor - rank_, // rank - "all_to_allv", // collective name - inputTensor.numel(), // inNelems - outputTensor.numel(), // outNelems - inputTensor.scalar_type(), // dType - inputSplitSizes, // inSplitSizes - outputSplitSizes, // outSplitSizes - -1, // globalRankStart - -1, // globalRankStride - this->getSize()); // worldSize - - return collective( - inputTensor, - outputTensor, - [&](at::Tensor& input, - at::Tensor& output, - xcclComm_t& comm, - at::xpu::XPUStream& stream) { - std::vector sendCounts(size_); - std::vector recvCounts(size_); - bool inputSplitsEqual = inputSplitSizes.size() == 0; - bool outputSplitsEqual = outputSplitSizes.size() == 0; - - size_t inLen = input.numel(); - size_t outLen = output.numel(); - if (inLen) - inLen /= (inputSplitsEqual ? size_ : input.size(0)); - if (outLen) - outLen /= (outputSplitsEqual ? size_ : output.size(0)); - - for (int i = 0; i < size_; i++) { - sendCounts[i] = - (inputSplitsEqual ? inLen : inputSplitSizes[i] * inLen); - recvCounts[i] = - (outputSplitsEqual ? 
 
-c10::intrusive_ptr<Work> ProcessGroupXCCL::alltoall(
-    std::vector<at::Tensor>& outputTensors,
-    std::vector<at::Tensor>& inputTensors,
-    const AllToAllOptions& /* unused */) {
-  auto device = outputTensors[0].device();
-  int64_t total_numel = 0;
-  for (const auto r : c10::irange(outputTensors.size())) {
-    check_xpu_single_tensor(outputTensors[r], true);
-    check_xpu_single_tensor(inputTensors[r], true);
-    TORCH_CHECK(
-        device == outputTensors[r].device() &&
-            device == inputTensors[r].device(),
-        "Tensors must be on the same device")
-    total_numel += inputTensors[r].numel();
-  }
-
-  RECORD_PARAM_COMMS_DATA(
-      static_cast<int>(
-          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
-      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
-      inputTensors, // inputTensors
-      outputTensors, // outputTensors
-      rank_, // rank
-      "all_to_all", // collective name
-      total_numel, // inNelems
-      total_numel, // outNelems
-      inputTensors.front().scalar_type(), // dType
-      std::vector<int64_t>(), // inSplitSizes
-      std::vector<int64_t>(), // outSplitSizes
-      -1, // globalRankStart
-      -1, // globalRankStride
-      this->getSize()); // worldSize
-
-  return collective(
-      inputTensors,
-      outputTensors,
-      [&](at::Tensor& /* unused */,
-          at::Tensor& /* unused */,
-          xcclComm_t& comm,
-          at::xpu::XPUStream& stream) {
-        c10::OptionalStreamGuard stream_guard(stream.unwrap());
-        at::Tensor flatInput;
-        at::Tensor flatOutput;
-
-        std::vector<size_t> sendCounts(size_);
-        std::vector<size_t> recvCounts(size_);
-
-        int64_t flatSendCount;
-        int64_t flatRecvCount;
-
-        bool isInputFlat = computeLengthsAndCheckAndGetFlat(
-            inputTensors, sendCounts, flatInput, flatSendCount);
-        bool isOutputFlat = computeLengthsAndCheckAndGetFlat(
-            outputTensors, recvCounts, flatOutput, flatRecvCount);
-        if (!isInputFlat) {
-          auto flatInputSplits = flatInput.split_with_sizes(
-              c10::IntArrayRef((int64_t*)sendCounts.data(), sendCounts.size()),
-              0);
-
-          for (int i = 0; i < size_; i++) {
-            flatInputSplits[i].copy_(inputTensors[i].view({-1}));
-          }
-        }
-
-        auto xcclDataType = getXcclDataType(flatOutput.scalar_type());
-        ccl::event ret_evt;
-        ret_evt = ccl::alltoallv(
-            flatInput.data_ptr(),
-            sendCounts,
-            flatOutput.data_ptr(),
-            recvCounts,
-            xcclDataType,
-            comm,
-            ccl::create_stream(stream.queue()));
-
-        if (!isOutputFlat) {
-          ret_evt.wait();
-          auto flatOutputSplits = flatOutput.split_with_sizes(
-              c10::IntArrayRef((int64_t*)recvCounts.data(), recvCounts.size()),
-              0);
-
-          for (int i = 0; i < size_; i++) {
-            outputTensors[i].view({-1}).copy_(flatOutputSplits[i]);
-          }
-        }
-        stream.synchronize();
-        return;
-      },
-      OpType::ALLTOALL,
-      "xccl:all_to_all");
-}
-
 } // namespace c10d
 
 #endif // USE_C10D_XCCL
diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp
index e2099e658..21269bd6f 100644
--- a/src/xccl/ProcessGroupXCCL.hpp
+++ b/src/xccl/ProcessGroupXCCL.hpp
@@ -26,7 +26,6 @@ static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
     "XCCL_BLOCKING_WAIT"};
 
 using xcclComm_t = ccl::communicator;
-using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
 
 class TORCH_API ProcessGroupXCCL : public Backend {
@@ -68,7 +67,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
  protected:
   at::Device device_;
   std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
-  at::Tensor barrierTensor_;
   bool blockingWait_ = false;
   std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
   uint64_t seq_;
@@ -95,18 +93,9 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     return std::string(XCCL_BACKEND_NAME);
   }
 
-  void startCoalescing() override;
-
-  c10::intrusive_ptr<Work> endCoalescing() override;
-
-  c10::intrusive_ptr<Work> endCoalescing(OpType optype);
-
   std::shared_ptr<xcclComm_t> getXCCLComm(
       const std::string& deviceKey,
-      at::Device& device,
-      OpType opType,
-      int p2pRank = 0,
-      bool isSendRecvSelf = false);
+      at::Device& device);
 
   virtual c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> initWork(
       at::Device& device,
       int rank,
@@ -123,39 +112,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       Fn fn,
       OpType opType,
       const char* profilingTitle = nullptr) {
-    return collective(
-        input,
-        output,
-        fn,
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        opType,
-        profilingTitle);
-  }
-
-  template <typename Fn, typename PreProcess, typename PostProcess>
-  c10::intrusive_ptr<Work> collective(
-      at::Tensor& input,
-      at::Tensor& output,
-      Fn fn,
-      PreProcess pre,
-      PostProcess post,
-      OpType opType,
-      const char* profilingTitle = nullptr) {
     auto inputs = std::vector<at::Tensor>{input};
     auto outputs = std::vector<at::Tensor>{output};
-    return collective(inputs, outputs, fn, pre, post, opType, profilingTitle);
-  }
-
-  template <typename Fn>
-  c10::intrusive_ptr<Work> collective(
-      std::vector<at::Tensor>& inputs,
-      std::vector<at::Tensor>& outputs,
-      Fn fn,
-      OpType opType,
-      const char* profilingTitle = nullptr) {
     return collective(
         inputs,
        outputs,
@@ -164,8 +122,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
         [](at::xpu::XPUStream&,
            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
-        opType,
-        profilingTitle);
+        opType);
   }
 
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
       std::vector<at::Tensor>& inputs,
       std::vector<at::Tensor>& outputs,
       Fn fn,
       PreProcess pre,
       PostProcess post,
       OpType opType,
       const char* profilingTitle = nullptr);
 
-  template <typename Fn>
-  c10::intrusive_ptr<Work> collectiveCoalesced(
-      std::vector<at::Tensor>& input,
-      std::vector<at::Tensor>& output,
-      Fn fn,
-      OpType opType,
-      const char* profilingTitle = nullptr) {
-    return collective(
-        input,
-        output,
-        fn,
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {
-          ccl::group_start();
-        },
-        [](at::xpu::XPUStream&,
-           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {
-          ccl::group_end();
-        },
-        opType,
-        profilingTitle);
-  }
-
-  template <typename Fn>
-  c10::intrusive_ptr<Work> pointToPoint(
-      at::Tensor& tensor,
-      Fn fn,
-      int peer,
-      OpType opType,
-      const char* profilingTitle = nullptr);
-
-  c10::intrusive_ptr<Work> allreduce_impl(
-      at::Tensor& tensor,
-      const AllreduceOptions& opts = AllreduceOptions());
-
   c10::intrusive_ptr<Work> allreduce(
       std::vector<at::Tensor>& tensors,
       const AllreduceOptions& opts = AllreduceOptions()) override;
 
-  c10::intrusive_ptr<Work> allreduce_coalesced(
-      std::vector<at::Tensor>& tensors,
-      const AllreduceCoalescedOptions& opts =
-          AllreduceCoalescedOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce(
-      std::vector<at::Tensor>& tensors,
-      const ReduceOptions& opts = ReduceOptions()) override;
-
-  c10::intrusive_ptr<Work> _reduce_oop(
-      at::Tensor& outputTensors,
-      at::Tensor& inputTensors,
-      const ReduceOptions& opts = ReduceOptions());
-
-  c10::intrusive_ptr<Work> broadcast(
-      std::vector<at::Tensor>& tensors,
-      const BroadcastOptions& opts = BroadcastOptions()) override;
-
-  c10::intrusive_ptr<Work> _broadcast_oop(
-      at::Tensor& outputTensor,
-      at::Tensor& inputTensor,
-      const BroadcastOptions& opts);
-
-  c10::intrusive_ptr<Work> allgather(
-      std::vector<std::vector<at::Tensor>>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> _allgather_base(
-      at::Tensor& outputbuffer,
-      at::Tensor& inputbuffer,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
-      std::vector<at::Tensor>& outputs,
-      std::vector<at::Tensor>& inputs,
-      const AllgatherOptions& opts = AllgatherOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce_scatter(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<std::vector<at::Tensor>>& inputTensors,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> _reduce_scatter_base(
-      at::Tensor& outputTensor,
-      at::Tensor& inputTensor,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
-      std::vector<at::Tensor>& outputs,
-      std::vector<at::Tensor>& inputs,
-      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
-
-  c10::intrusive_ptr<Work> barrier(
-      const BarrierOptions& opts = BarrierOptions()) override;
-
-  c10::intrusive_ptr<Work> alltoall_base(
-      at::Tensor& outputTensor,
-      at::Tensor& inputTensor,
-      std::vector<int64_t>& outputSplitSizes,
-      std::vector<int64_t>& inputSplitSizes,
-      const AllToAllOptions& opts = AllToAllOptions()) override;
-
-  c10::intrusive_ptr<Work> alltoall(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const AllToAllOptions& opts = AllToAllOptions()) override;
-
-  c10::intrusive_ptr<Work> send(
-      std::vector<at::Tensor>& tensors,
-      int dstRank,
-      int tag) override;
-
-  c10::intrusive_ptr<Work> recv(
-      std::vector<at::Tensor>& tensors,
-      int srcRank,
-      int tag) override;
-
-  void groupStart();
-
-  void groupEnd();
-
-  c10::intrusive_ptr<Work> gather(
-      std::vector<std::vector<at::Tensor>>& outputTensors,
-      std::vector<at::Tensor>& inputTensors,
-      const GatherOptions& opts = GatherOptions()) override;
-
-  c10::intrusive_ptr<Work> scatter(
-      std::vector<at::Tensor>& outputTensors,
-      std::vector<std::vector<at::Tensor>>& inputTensors,
-      const ScatterOptions& opts = ScatterOptions()) override;
-
-  void setSequenceNumberForGroup() override;
-
-  uint64_t getSequenceNumberForGroup() override;
+  void setSequenceNumberForGroup() override {}
+  uint64_t getSequenceNumberForGroup() override {
+    return seqCollective_;
+  }
 
  protected:
   std::unordered_map<std::string, at::xpu::XPUStream> xcclStreamsMap_;
@@ -320,14 +151,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   c10::intrusive_ptr<Store> store_;
   uint64_t xcclCommCounter_{0};
   std::mutex mutex_;
-  std::set<int> usedDeviceIdxs_;
-  int coalescing_state_ = 0;
-  at::Device coalescedDevice_ = at::Device("xpu");
-  std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
   bool blockingWait_ = false;
-  static thread_local uint64_t xcclActiveGroupCounter_;
   uint64_t seqCollective_{0};
-  uint64_t seqP2P_{0};
 
  private:
   std::mutex kvs_mutex;