diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
index ff47e2cf21301..8a7aefdc238c4 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -180,12 +180,13 @@ ncclRedOpRAII getNcclReduceOp(
           NCCL_MINOR));
       break;
     case ReduceOp::BAND:
+      C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL");
+      break;
     case ReduceOp::BOR:
+      C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL");
+      break;
     case ReduceOp::BXOR:
-      C10_THROW_ERROR(
-          ValueError,
-          "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) +
-              " with NCCL");
+      C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL");
       break;
     default:
       C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
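Note on the ProcessGroupNCCL.cpp hunk: once each unsupported case carries its own literal message, the fall-through grouping and the shared op-to-string helper are no longer needed on this path. A minimal standalone sketch of the resulting pattern (hypothetical names, not the PyTorch source):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>

// Stand-ins for c10d::ReduceOp and C10_THROW_ERROR (assumed, for illustration).
enum class ReduceOp { SUM, BAND, BOR, BXOR };

std::string mapToNccl(ReduceOp op) {
  switch (op) {
    case ReduceOp::SUM:
      return "ncclSum";
    // Each unsupported op throws with its own literal message, so no
    // op-to-string helper is needed to build the error text.
    case ReduceOp::BAND:
      throw std::invalid_argument("Cannot use ReduceOp.BAND with NCCL");
    case ReduceOp::BOR:
      throw std::invalid_argument("Cannot use ReduceOp.BOR with NCCL");
    case ReduceOp::BXOR:
      throw std::invalid_argument("Cannot use ReduceOp.BXOR with NCCL");
    default:
      throw std::invalid_argument("Unhandled ReduceOp");
  }
}

int main() {
  try {
    mapToNccl(ReduceOp::BAND);
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << "\n";  // "Cannot use ReduceOp.BAND with NCCL"
  }
}
```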
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index caf9aeee5a6bd..291a27a53559d 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -16,14 +16,14 @@ namespace c10d {

 namespace {

-std::map<ReduceOp, ccl::reduction> xcclOps = {
+const std::map<ReduceOp, ccl::reduction> xcclOps = {
     {ReduceOp::MIN, ccl::reduction::min},
     {ReduceOp::MAX, ccl::reduction::max},
     {ReduceOp::SUM, ccl::reduction::sum},
     {ReduceOp::PRODUCT, ccl::reduction::prod},
 };

-std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
+const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kByte, ccl::datatype::uint8},
     {at::kChar, ccl::datatype::int8},
     {at::kInt, ccl::datatype::int32},
@@ -35,7 +35,7 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kBool, ccl::datatype::uint8},
 };

-void check_xpu_single_tensor(const at::Tensor& tensor) {
+void checkXPUTensor(at::Tensor& tensor) {
   if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
     C10_THROW_ERROR(
         ValueError, "Tensors must be XPU and dense and non-complex");
@@ -62,19 +62,15 @@ ccl::datatype getXcclDataType(

 ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
   try {
-    if (input.scalar_type() == at::kBool) {
-      if (reduceOp == ReduceOp::SUM) {
-        // For bool tensors, map sum to max, which both represent a bitwise
-        // or. This is to prevent overflow issues with sum, since we use uint8
-        // to represent a bool (see xcclDatatypes mapping align with cuda).
-        return ccl::reduction::max;
-      }
+    if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
+      // Map sum to max for bool tensors to avoid overflow issues with sum.
+      return ccl::reduction::max;
     }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
     C10_THROW_ERROR(
         ValueError,
-        "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
+        "Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL");
   }
 }

@@ -85,7 +81,6 @@ void syncStream(
   xcclEvent.record(at::xpu::getCurrentXPUStream(device.index()));
   xcclEvent.block(xcclStream);
 }
-
 } // namespace

 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
@@ -98,8 +93,7 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     : Work(rank, opType, "profilingTitle", inputs),
       device_(device),
       workStartTime_(std::chrono::steady_clock::now()) {
-  unsigned char enable_timing = 0;
-  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enable_timing);
+  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
 }

 ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
@@ -133,12 +127,11 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
         currentTimepoint - workStartTime_);
     if (timeElapsed >= timeout) {
       std::string exceptionMsg = c10::str(
-          "Work ran for ",
+          "Work timed out after ",
           timeElapsed.count(),
-          " milliseconds before timing out.");
+          " milliseconds.");
       TORCH_CHECK(false, exceptionMsg)
     }
-
     std::this_thread::sleep_for(
         std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
   }
@@ -157,20 +150,6 @@ ProcessGroupXCCL::ProcessGroupXCCL(
     : Backend(rank, size), store_(store) {
   blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
   init();
-
-  // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI
-  // launchers.
-  if (!with_mpirun()) {
-    int local_rank = getXCCLEnvVar("LOCAL_RANK");
-    int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE");
-    if (local_rank == -1 || local_world_size == -1) {
-      local_rank = rank;
-      local_world_size = size;
-    }
-    setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none");
-    setXCCLEnvVar("CCL_LOCAL_RANK", local_rank);
-    setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size);
-  }
 }

 ProcessGroupXCCL::~ProcessGroupXCCL() = default;
@@ -189,24 +168,16 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
 std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
     const std::string& deviceKey,
     at::Device& device) {
-  if (deviceKey.empty()) {
-    C10_THROW_ERROR(
-        DistBackendError,
-        "Not able to create/get the XCCL Communicator since "
-        "the devices are empty ");
-  }
-
+  TORCH_CHECK_WITH(
+      DistBackendError, !deviceKey.empty(),
+      "Not able to create/get XCCL Communicator since the devices are empty");
   {
+    // TODO: why do we need the mutex here?
     std::lock_guard<std::mutex> lock(mutex_);
     if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
       return devXCCLCommMap_[deviceKey];
     }
   }
-
-  std::shared_ptr<xcclComm_t> XCCLComm;
-
-  XCCL_KVS kvs = get_kvs(rank_, *store_);
-
   int numRanks, rank;
   numRanks = getSize();
   rank = getRank();
@@ -219,14 +190,12 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
   devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));

-  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
-  XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
-
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    devXCCLCommMap_.emplace(deviceKey, XCCLComm);
-  }
+  auto xccl_kvs = get_kvs(rank_, *store_);
+  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
+  std::shared_ptr<xcclComm_t> XCCLComm =
+      std::make_shared<xcclComm_t>(std::move(comms[0]));

+  std::lock_guard<std::mutex> lock(mutex_);
+  devXCCLCommMap_.emplace(deviceKey, XCCLComm);
   xcclStreamsMap_.emplace(deviceKey, std::move(stream));
   xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
@@ -250,11 +219,9 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
   work = initWork(device, rank_, opType);
-
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);

   at::xpu::OptionalXPUGuard gpuGuard(device);
-
   pre(stream, work);

   for (const auto i : c10::irange(inputs.size())) {
     c10::xpu::XPUCachingAllocator::recordStream(
@@ -264,7 +231,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   post(stream, work);

   work->xcclEndEvent_->record(stream);
-
   std::vector<c10::Stream> streams = {stream.unwrap()};
   c10::MultiStreamGuard streamGuard(streams);
   std::vector<at::Device> devices{device};
@@ -282,7 +248,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
   TORCH_CHECK(
       tensors.size() == 1, "Expecting one tensor only but got multiple");
   auto tensor = tensors.back();
-  check_xpu_single_tensor(tensor);
+  checkXPUTensor(tensor);
   return collective(
       tensor,
       tensor,
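Note on the SUM-to-max remapping in getXcclReduceOp: bool tensors travel on the wire as uint8 (see the at::kBool entry in xcclDatatypes), so a plain sum can wrap and report false even when inputs were true, while max preserves the intended logical-OR semantics. A standalone sketch (not PyTorch code) of the failure mode:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int ranks = 256;  // enough contributions to wrap a uint8
  uint8_t sum = 0;
  uint8_t mx = 0;
  for (int r = 0; r < ranks; ++r) {
    const uint8_t contribution = 1;  // every rank holds `true`
    sum = static_cast<uint8_t>(sum + contribution);  // 255 + 1 wraps to 0
    mx = std::max(mx, contribution);                 // saturates at 1
  }
  // sum reports `false` even though every input was `true`; max stays correct.
  std::cout << "sum=" << int(sum) << " max=" << int(mx) << "\n";
}
```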
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 41b78035c83f9..155082fe6d734 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -35,7 +35,6 @@ static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
     "XCCL_BLOCKING_WAIT"};

 using xcclComm_t = ccl::communicator;
-using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";

 class TORCH_API ProcessGroupXCCL : public Backend {
@@ -113,9 +112,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       at::Tensor& output,
       Fn fn,
       OpType opType) {
+    auto inputs = std::vector<at::Tensor>{input};
+    auto outputs = std::vector<at::Tensor>{output};
     return collective(
-        input,
-        output,
+        inputs,
+        outputs,
         fn,
         [](at::xpu::XPUStream&,
            c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
@@ -124,19 +125,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
         opType);
   }

-  template <typename Fn, typename PreProcess, typename PostProcess>
-  c10::intrusive_ptr<Work> collective(
-      at::Tensor& input,
-      at::Tensor& output,
-      Fn fn,
-      PreProcess pre,
-      PostProcess post,
-      OpType opType) {
-    auto inputs = std::vector<at::Tensor>{input};
-    auto outputs = std::vector<at::Tensor>{output};
-    return collective(inputs, outputs, fn, pre, post, opType);
-  }
-
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
       std::vector<at::Tensor>& inputs,
@@ -159,9 +147,11 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   bool blockingWait_ = false;

  private:
-  XCCL_KVS kvs;
   std::mutex kvs_mutex;
-  XCCL_KVS get_kvs(int rank, c10d::Store& store) {
+  ccl::shared_ptr_class<ccl::kvs> kvs;
+
+  ccl::shared_ptr_class<ccl::kvs> get_kvs(int rank, c10d::Store& store) {
+    // TODO: why do we need the mutex here?
     std::lock_guard<std::mutex> lock(kvs_mutex);
     if (kvs)
       return kvs;
@@ -186,41 +176,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     return kvs;
   }
 };
-
-namespace {
-int getXCCLEnvVar(std::string envVarName) {
-  char* stringValue = std::getenv(envVarName.c_str());
-  if (stringValue != nullptr) {
-    try {
-      int val = std::stoi(stringValue);
-      return val;
-    } catch (std::exception& e) {
-      TORCH_CHECK(
-          false,
-          "Invalid value for environment variable: " + std::string(envVarName));
-    }
-  } else {
-    return -1;
-  }
-}
-
-template <typename T>
-void setXCCLEnvVar(const std::string& envVarName, T val) {
-  if constexpr (std::is_same_v<T, int>) {
-    setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
-  } else if constexpr (std::is_same_v<T, std::string>) {
-    setenv(envVarName.c_str(), val.c_str(), 1);
-  }
-}
-
-bool with_mpirun() {
-  return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
-          getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
-      ? true
-      : false;
-}
-
-} // namespace
 } // namespace c10d

 #endif // USE_C10D_XCCL
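On the `// TODO: why do we need the mutex here?` question in get_kvs(): this is lazy initialization of shared state, so without the lock two threads could both observe the KVS as empty and each create one, a classic check-then-act race. A minimal sketch of the pattern with stand-in types (not the oneCCL API):

```cpp
#include <memory>
#include <mutex>

struct Kvs {};  // stand-in for ccl::kvs

class KvsCache {
 public:
  std::shared_ptr<Kvs> get() {
    // The lock serializes the "check, then create" sequence so the
    // expensive creation runs exactly once and all callers share it.
    std::lock_guard<std::mutex> lock(mutex_);
    if (!kvs_) {
      kvs_ = std::make_shared<Kvs>();
    }
    return kvs_;
  }

 private:
  std::mutex mutex_;
  std::shared_ptr<Kvs> kvs_;
};
```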
diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp
index 73e37e0437c45..e27ec363ba1cc 100644
--- a/torch/csrc/distributed/c10d/Utils.hpp
+++ b/torch/csrc/distributed/c10d/Utils.hpp
@@ -557,7 +557,7 @@ size_t computeLengthsAndOffsets(
   return offset;
 }

-inline std::string reduce_op_to_string(c10d::ReduceOp op) {
+inline std::string reduceOpToString(c10d::ReduceOp op) {
   switch (op) {
     case c10d::ReduceOp::SUM:
       return "SUM";
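For reference, how the renamed helper composes at the XCCL call site; a simplified sketch with an abbreviated switch (the real Utils.hpp version covers every ReduceOp value):

```cpp
#include <iostream>
#include <string>

enum class ReduceOp { SUM, MIN, MAX, BAND };

// Abbreviated stand-in for the renamed helper in Utils.hpp.
inline std::string reduceOpToString(ReduceOp op) {
  switch (op) {
    case ReduceOp::SUM:
      return "SUM";
    case ReduceOp::MIN:
      return "MIN";
    case ReduceOp::MAX:
      return "MAX";
    case ReduceOp::BAND:
      return "BAND";
    default:
      return "UNKNOWN";
  }
}

int main() {
  // Mirrors how getXcclReduceOp's catch block builds its error text.
  std::cout << "Cannot use ReduceOp." + reduceOpToString(ReduceOp::BAND) +
          " with XCCL"
            << "\n";
}
```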