diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index b4ec018019f16..25bd7f700f68a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1376,9 +1376,6 @@ if(USE_DISTRIBUTED)
   endif()
   if(USE_XPU AND USE_C10D_XCCL)
     target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
-    set_source_files_properties(
-      ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp
-      PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL")
   endif()
   if(USE_MPI AND USE_C10D_MPI)
     if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake
index a717ad1dafc65..18f7ac642d54e 100644
--- a/cmake/Modules/FindXCCL.cmake
+++ b/cmake/Modules/FindXCCL.cmake
@@ -6,9 +6,10 @@
 include(FindPackageHandleStandardArgs)
 
-set(XCCL_ROOT "")
-if(DEFINED ENV{CCL_ROOT})
-  set(XCCL_ROOT $ENV{CCL_ROOT})
+set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest")
+if (NOT EXISTS "${XCCL_ROOT}")
+  message(STATUS "Default OneCCL not found, using current environment OneAPI")
+  set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest)
 endif()
 
 string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound)
diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 903df26bba9f6..d3cb65f7befb1 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
+    get_device_count,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -60,17 +61,13 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def gpus_for_rank(world_size):
+def gpus_for_rank(world_size, backend):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. Nccl backend requires equal #GPUs in each process.
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a
     subset.
""" - device_count = ( - torch.xpu.device_count() - if torch.xpu.is_available() - else torch.cuda.device_count() - ) + device_count = get_device_count(backend) visible_devices = list(range(device_count)) gpus_per_process = device_count // world_size gpus_for_rank = [] @@ -833,7 +830,7 @@ def update_parameters(model): def _gpu_model_with_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False, state=None ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -850,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook( def _gpu_model_with_builtin_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index c7f9609bcf0cd..d360b219edc5d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -51,8 +51,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { NCCL = 2, UCC = 3, MPI = 4, - CUSTOM = 5, - XCCL = 6, + XCCL = 5, + CUSTOM = 6, }; static std::string backendTypeToString(const BackendType& type) { @@ -131,6 +131,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendType_; }; + inline bool backendSupportsSequenceNumbers(BackendType backendType) { + if (backendType == BackendType::GLOO || backendType == BackendType::NCCL || + backendType == BackendType::XCCL || backendType == BackendType::UCC) + return true; + return false; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -508,10 +515,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { virtual void setSequenceNumberForGroup() { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::XCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { getDefaultBackend()->setSequenceNumberForGroup(); } else { TORCH_CHECK( @@ -530,10 +534,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. 
-    if (backendType == ProcessGroup::BackendType::GLOO ||
-        backendType == ProcessGroup::BackendType::NCCL ||
-        backendType == ProcessGroup::BackendType::XCCL ||
-        backendType == ProcessGroup::BackendType::UCC) {
+    if (backendSupportsSequenceNumbers(backendType)) {
       return getDefaultBackend()->getSequenceNumberForGroup();
     } else {
       TORCH_CHECK(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index 365640d137778..a89a7c48a01ff 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -1,40 +1,20 @@
 #ifdef USE_C10D_XCCL
 
 #include
+#include
 #include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
 
 namespace c10d {
 
 namespace {
-
-// wait nonblocking implement
-AutoXcclGroup::AutoXcclGroup() {
-  ccl::group_start();
-}
-
-AutoXcclGroup::~AutoXcclGroup() noexcept(false) {
-  ccl::group_end();
-}
-
-std::map<ReduceOp, ccl::reduction> xcclOps = {
+const std::map<ReduceOp, ccl::reduction> xcclOps = {
     {ReduceOp::MIN, ccl::reduction::min},
     {ReduceOp::MAX, ccl::reduction::max},
     {ReduceOp::SUM, ccl::reduction::sum},
     {ReduceOp::PRODUCT, ccl::reduction::prod},
 };
 
-std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
+const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kByte, ccl::datatype::uint8},
     {at::kChar, ccl::datatype::int8},
     {at::kInt, ccl::datatype::int32},
@@ -158,19 +138,15 @@ ccl::datatype getXcclDataType(at::ScalarType type) {
 ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
   try {
-    if (input.scalar_type() == at::kBool) {
-      if (reduceOp == ReduceOp::SUM) {
-        // For bool tensors, map sum to max, which both represent a bitwise or.
-        // This is to prevent overflow issues with sum, since we use uint8 to
-        // represent a bool (see xcclDatatypes mapping align with cuda).
-        return ccl::reduction::max;
-      }
+    if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
+      // Map sum to max for bool tensors to avoid overflow issues with sum.
+      return ccl::reduction::max;
     }
     return xcclOps.at(reduceOp);
   } catch (const std::out_of_range&) {
     C10_THROW_ERROR(
         ValueError,
-        "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
+        "Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL");
   }
 }
 
@@ -196,8 +172,6 @@ bool complexViewAsRealAllowed(const ReduceOp reduceOp) {
 
 } // namespace
 
-static std::mutex xcclCommDevIdxMapMutex;
-static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap;
 constexpr int64_t kSynchronizeBusyWaitMillis = 10;
 thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0;
 
@@ -205,12 +179,14 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(
     at::Device& device,
     int rank,
     OpType opType,
+    uint64_t seq,
+    const char* profilingTitle,
     const std::optional<std::vector<at::Tensor>>& inputs)
-    : Work(rank, opType, "profilingTitle", inputs),
+    : Work(rank, opType, profilingTitle, inputs),
       device_(device),
-      workStartTime_(std::chrono::steady_clock::now()) {
-  unsigned char enable_timing = 0;
-  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>(enable_timing);
+      workStartTime_(std::chrono::steady_clock::now()),
+      seq_(seq) {
+  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
 }
 
 ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
@@ -218,7 +194,8 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
       device_(w.device_),
       xcclEndEvent_(w.xcclEndEvent_),
       blockingWait_(w.blockingWait_),
-      workStartTime_(w.workStartTime_) {}
+      workStartTime_(w.workStartTime_),
+      seq_(w.seq_) {}
 
 ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
 
@@ -244,12 +221,9 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
           currentTimepoint - workStartTime_);
       if (timeElapsed >= timeout) {
         std::string exceptionMsg = c10::str(
-            "Work ran for ",
-            timeElapsed.count(),
-            " milliseconds before timing out.");
+            "Work timed out after ", timeElapsed.count(), " milliseconds.");
         TORCH_CHECK(false, exceptionMsg)
       }
-
       std::this_thread::sleep_for(
           std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
     }
@@ -275,20 +249,6 @@ ProcessGroupXCCL::ProcessGroupXCCL(
     : Backend(rank, size), store_(store) {
   blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
   init();
-
-  // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI
-  // launchers.
-  if (!with_mpirun()) {
-    int local_rank = getXCCLEnvVar("LOCAL_RANK");
-    int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE");
-    if (local_rank == -1 || local_world_size == -1) {
-      local_rank = rank;
-      local_world_size = size;
-    }
-    setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none");
-    setXCCLEnvVar("CCL_LOCAL_RANK", local_rank);
-    setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size);
-  }
 }
 
 ProcessGroupXCCL::~ProcessGroupXCCL() = default;
 
@@ -297,10 +257,16 @@ c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
     at::Device& device,
     int rank,
     OpType opType,
+    const char* profilingTitle,
     const std::vector<at::Tensor>& inputs,
     const std::vector<at::Tensor>& outputs) {
   auto r = c10::make_intrusive<ProcessGroupXCCL::WorkXCCL>(
-      device, rank, opType, std::optional<std::vector<at::Tensor>>(inputs));
+      device,
+      rank,
+      opType,
+      seqCollective_,
+      profilingTitle,
+      std::optional<std::vector<at::Tensor>>(inputs));
   return r;
 }
 
@@ -320,6 +286,7 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   usedDeviceIdxs_.insert(device.index());
 
   {
+    // todo: why do we need mutex here?
     std::lock_guard<std::mutex> lock(mutex_);
     if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
       return devXCCLCommMap_[deviceKey];
     }
   }
 
@@ -352,20 +319,23 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   }
 
   c10::impl::VirtualGuardImpl impl(device.type());
-  c10::Stream stream = impl.getStream(device);
+  c10::Stream stream =
+      impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
   sycl::queue& q = c10::xpu::XPUStream(stream).queue();
 
   auto ctx = ccl::create_context(q.get_context());
   ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
   devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
 
-  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs);
-  XCCLComm = std::make_shared<xcclComm_t>(std::move(comms[0]));
+  auto xccl_kvs = get_kvs(rank_, *store_);
+  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
+  std::shared_ptr<xcclComm_t> XCCLComm =
+      std::make_shared<xcclComm_t>(std::move(comms[0]));
 
-  {
-    std::lock_guard<std::mutex> lock(mutex_);
-    inInitializationCommMap_.emplace(deviceKey, XCCLComm);
-  }
+  std::lock_guard<std::mutex> lock(mutex_);
+  devXCCLCommMap_.emplace(deviceKey, XCCLComm);
+  xcclStreamsMap_.emplace(deviceKey, std::move(stream));
+  xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
 
   for (const auto i : c10::irange(xcclActiveGroupCounter_)) {
     (void)i;
@@ -510,8 +480,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
 
 template <typename Fn, typename PreProcess, typename PostProcess>
 c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
-    at::Tensor& input,
-    at::Tensor& output,
+    std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs,
     Fn fn,
     PreProcess pre,
     PostProcess post,
@@ -563,27 +533,23 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collectiveCoalesced(
     }
   }
 
-  auto stream = xcclStreams_.at(key);
-  syncStream(device, xcclEvents_[key], stream);
+  auto stream = xcclStreamsMap_.at(key);
+  syncStream(device, xcclEventsMap_[key], stream);
 
   c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
-  work = initWork(device, rank_, opType);
-
+  work = initWork(device, rank_, opType, profilingTitle);
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
 
   at::xpu::OptionalXPUGuard gpuGuard(device);
-
-  {
-    AutoXcclGroup xccl_group_guard;
-    for (const auto i : c10::irange(inputs.size())) {
-      c10::xpu::XPUCachingAllocator::recordStream(
-          inputs[i].storage().data_ptr(), stream);
-      fn(inputs[i], outputs[i], *comm, stream);
-    }
+  pre(stream, work);
+  for (const auto i : c10::irange(inputs.size())) {
+    c10::xpu::XPUCachingAllocator::recordStream(
+        inputs[i].storage().data_ptr(), stream);
+    fn(inputs[i], outputs[i], *comm, stream);
   }
+  post(stream, work);
 
   work->xcclEndEvent_->record(stream);
-
   std::vector<c10::Stream> streams = {stream.unwrap()};
   c10::MultiStreamGuard streamGuard(streams);
   std::vector<at::Device> devices{device};
@@ -973,8 +939,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
           at::xpu::XPUStream& stream) {
         auto xcclDataType = getXcclDataType(input.scalar_type());
         auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
-        ccl::event ret_evt;
-        ret_evt = ccl::allreduce(
+        auto ccl_stream = ccl::create_stream(stream.queue());
+        ccl::allreduce(
            input.data_ptr(),
            output.data_ptr(),
            (size_t)input.numel(),
@@ -984,7 +950,8 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce_impl(
            ccl::create_stream(stream.queue()));
        return ret_evt;
      },
-      OpType::ALLREDUCE);
+      OpType::ALLREDUCE,
+      "xccl:all_reduce");
 }
 
 c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
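As a quick sanity check for the allreduce path wired up above, here is a minimal usage sketch in Python. It assumes a PyTorch build with USE_C10D_XCCL, one XPU per rank, and an external launcher that sets MASTER_ADDR/MASTER_PORT plus rank/world size; the function name `run_allreduce` is illustrative only, not part of this patch.

```python
import torch
import torch.distributed as dist


def run_allreduce(rank: int, world_size: int) -> None:
    # "xccl" is the backend name registered for XPU tensors
    # (XCCL_BACKEND_NAME in ProcessGroupXCCL.hpp).
    dist.init_process_group("xccl", rank=rank, world_size=world_size)
    t = torch.ones(8, device=f"xpu:{rank}")
    # Dispatches to ProcessGroupXCCL::allreduce -> allreduce_impl above.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert torch.equal(t.cpu(), torch.full((8,), float(world_size)))
    dist.destroy_process_group()
```
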
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
index 88e87ed6713dd..0f2b2738a4b77 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -1,33 +1,24 @@
 #pragma once
 
-#if defined(__linux__)
-#include
-#include
-#include
-#include
-#endif
-
 #ifdef USE_C10D_XCCL
-#include
+// We will define those flags in XCCL backend file instead of passing to gcc
+// compiler.
+#define CCL_ENABLE_ZE
+#define CCL_ENABLE_SYCL
+
 #include
-#include
 #include
-#include
-#include
-
-#include
-#include
 #include
-#include
 #include
 #include
-#include
 #include
+#include
+#include
 #include
 #include
 #include
-#include
 
 namespace c10d {
 namespace {
@@ -42,7 +33,6 @@ static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
     "XCCL_BLOCKING_WAIT"};
 
 using xcclComm_t = ccl::communicator;
-using XCCL_KVS = ccl::shared_ptr_class<ccl::kvs>;
 constexpr const char* XCCL_BACKEND_NAME = "xccl";
 
 class TORCH_API ProcessGroupXCCL : public Backend {
@@ -53,17 +43,14 @@ class TORCH_API ProcessGroupXCCL : public Backend {
         at::Device& device,
         int rank,
         OpType opType,
+        uint64_t seq,
+        const char* profilingTitle = nullptr,
         const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
     WorkXCCL(const WorkXCCL& w);
     ~WorkXCCL() override;
 
     bool isCompleted() override;
 
-    bool isSuccess() const override {
-      TORCH_CHECK(
-          false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented");
-    }
-
     void abort() override {
       TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented");
     }
@@ -76,6 +63,10 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       return future_;
     }
 
+    uint64_t getSequencenumber() const override {
+      return seq_;
+    }
+
     std::vector<at::Tensor> result() override {
       return *outputs_;
     }
@@ -86,6 +77,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     at::Tensor barrierTensor_;
     bool blockingWait_ = false;
     std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
+    uint64_t seq_;
 
    private:
     void synchronizeInternal(std::chrono::milliseconds timeout);
@@ -126,6 +118,7 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       at::Device& device,
       int rank,
      OpType opType,
+      const char* profilingTitle = nullptr,
      const std::vector<at::Tensor>& inputs = {},
      const std::vector<at::Tensor>& outputs = {});
 
@@ -134,16 +127,30 @@ class TORCH_API ProcessGroupXCCL : public Backend {
      at::Tensor& input,
      at::Tensor& output,
      Fn fn,
-      OpType opType);
+      OpType opType,
+      const char* profilingTitle = nullptr) {
+    auto inputs = std::vector<at::Tensor>{input};
+    auto outputs = std::vector<at::Tensor>{output};
+    return collective(
+        inputs,
+        outputs,
+        fn,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        opType);
+  }
 
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
-      at::Tensor& input,
-      at::Tensor& output,
+      std::vector<at::Tensor>& inputs,
+      std::vector<at::Tensor>& outputs,
       Fn fn,
       PreProcess pre,
       PostProcess post,
-      OpType opType);
+      OpType opType,
+      const char* profilingTitle = nullptr);
 
   template <typename Fn, typename PreProcess, typename PostProcess>
   c10::intrusive_ptr<Work> collective(
@@ -218,13 +225,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
       at::Tensor& inputbuffer,
       const AllgatherOptions& opts = AllgatherOptions()) override;
 
-  c10::intrusive_ptr<Work> allgather_coalesced(
-      std::vector<std::vector<at::Tensor>>& outputTensorLists,
-      std::vector<at::Tensor>& inputTensors,
-      const AllgatherOptions& opts = AllgatherOptions()) override {
-    TORCH_CHECK(false, "ProcessGroupXCCL::allgather_coalesced not implemented");
-  }
-
   c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
       std::vector<at::Tensor>& outputs,
       std::vector<at::Tensor>& inputs,
@@ -298,10 +298,13 @@ class TORCH_API ProcessGroupXCCL : public Backend {
   std::shared_ptr<xcclComm_t> coalescedComm_ = nullptr;
   bool blockingWait_ = false;
   static thread_local uint64_t xcclActiveGroupCounter_;
+  uint64_t seqCollective_{0};
+
  private:
-  XCCL_KVS kvs;
   std::mutex kvs_mutex;
-  XCCL_KVS get_kvs(int rank, c10d::Store& store) {
+  ccl::shared_ptr_class<ccl::kvs> kvs;
+
+  ccl::shared_ptr_class<ccl::kvs> get_kvs(int rank, c10d::Store& store) {
     std::lock_guard<std::mutex> lock(kvs_mutex);
     if (kvs)
       return kvs;
@@ -326,41 +329,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
     return kvs;
   }
 };
-
-namespace {
-int getXCCLEnvVar(std::string envVarName) {
-  char* stringValue = std::getenv(envVarName.c_str());
-  if (stringValue != nullptr) {
-    try {
-      int val = std::stoi(stringValue);
-      return val;
-    } catch (std::exception& e) {
-      TORCH_CHECK(
-          false,
-          "Invalid value for environment variable: " + std::string(envVarName));
-    }
-  } else {
-    return -1;
-  }
-}
-
-template <typename T>
-void setXCCLEnvVar(const std::string& envVarName, T val) {
-  if constexpr (std::is_same_v<T, int>) {
-    setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
-  } else if constexpr (std::is_same_v<T, std::string>) {
-    setenv(envVarName.c_str(), val.c_str(), 1);
-  }
-}
-
-bool with_mpirun() {
-  return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
-          getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
-      ? true
-      : false;
-}
-
-} // namespace
 } // namespace c10d
 
 #endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp
index 73e37e0437c45..e27ec363ba1cc 100644
--- a/torch/csrc/distributed/c10d/Utils.hpp
+++ b/torch/csrc/distributed/c10d/Utils.hpp
@@ -557,7 +557,7 @@ size_t computeLengthsAndOffsets(
   return offset;
 }
 
-inline std::string reduce_op_to_string(c10d::ReduceOp op) {
+inline std::string reduceOpToString(c10d::ReduceOp op) {
   switch (op) {
     case c10d::ReduceOp::SUM:
       return "SUM";
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 4bbb1c4101123..fc4ca55dbd023 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1675,13 +1675,9 @@ def _new_process_group_helper(
             "created, please use a different group name"
         )
 
-    if device_id is not None and (
-        device_id.index is None
-        or (device_id.type != "cuda" and device_id.type != "xpu")
-    ):
+    if device_id is not None and device_id.index is None:
         raise ValueError(
-            "init_process_group device_id parameter must be a cuda device with an "
-            "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu"
+            "init_process_group device_id parameter must be a device with an index"
         )
 
     # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 9ec38c9ca671c..3e1664690b713 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -93,8 +93,9 @@ class DistTestCases:
     # Sets showing that something is implemented
     backend_feature = {}
-    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"}
     backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["xpu"] = {"xccl"}
     backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
     backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
     backend_feature["plugin"] = set()
@@ -462,6 +463,15 @@ def compute_sum(fn, world_size: int):
     ]
 ]
 
+# Returns the number of GPUs, currently only for CUDA and XPU.
+def get_device_count(backend: str):
+    assert c10d.is_backend_available(backend)
+    if backend in backend_feature.get("cuda", set()):
+        return torch.cuda.device_count()
+    elif backend in backend_feature.get("xpu", set()):
+        return torch.xpu.device_count()
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
+
 
 # HELPER FOR MULTIGPU TESTS
 def init_multigpu_helper(world_size: int, backend: str):
@@ -470,7 +480,7 @@ def init_multigpu_helper(world_size: int, backend: str):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
+    nGPUs = get_device_count(backend)
     visible_devices = range(nGPUs)
 
     # If rank is less than or equal to number of available GPU's
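
For completeness, a short sketch of how the updated test helpers fit together on the XPU side. The concrete numbers are hypothetical (a single node with 8 XPUs and a world size of 2), and only helpers introduced or changed in this patch are referenced.

```python
from torch.testing._internal.common_distributed import get_device_count

world_size = 2
backend = "xccl"

# For an "xccl" process group this resolves to torch.xpu.device_count(),
# e.g. 8 on the hypothetical host.
device_count = get_device_count(backend)
gpus_per_process = device_count // world_size  # 4 devices per rank

# Tests now thread the backend name through, e.g. in test_c10d_common.py:
#   device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
```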