From 652da01588ecb64e354707bc52496c28c01f07ce Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 29 Aug 2024 09:28:58 +0000 Subject: [PATCH 01/30] Xccl process group for Pytorch --- CMakeLists.txt | 6 + build_variables.bzl | 4 + caffe2/CMakeLists.txt | 13 + caffe2/core/macros.h.in | 1 + cmake/Dependencies.cmake | 16 + cmake/External/xccl.cmake | 17 + cmake/Modules/FindXCCL.cmake | 68 +++ cmake/Summary.cmake | 6 + setup.py | 4 + test/distributed/test_c10d_common.py | 9 +- test/distributed/test_c10d_xccl.py | 303 +++++++++++++ torch/CMakeLists.txt | 7 + torch/_C/_distributed_c10d.pyi | 9 + torch/csrc/distributed/c10d/Ops.cpp | 20 + torch/csrc/distributed/c10d/ProcessGroup.cpp | 2 + torch/csrc/distributed/c10d/ProcessGroup.hpp | 3 + .../distributed/c10d/ProcessGroupXCCL.cpp | 401 ++++++++++++++++++ .../distributed/c10d/ProcessGroupXCCL.hpp | 308 ++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 22 + torch/distributed/distributed_c10d.py | 48 ++- torch/testing/_internal/common_distributed.py | 11 +- 21 files changed, 1268 insertions(+), 10 deletions(-) create mode 100644 cmake/External/xccl.cmake create mode 100644 cmake/Modules/FindXCCL.cmake create mode 100644 test/distributed/test_c10d_xccl.py create mode 100644 torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp create mode 100644 torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5139c0a478e78..89ef59681bfff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -275,6 +275,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF) cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF) cmake_dependent_option(USE_NCCL "Use NCCL" ON "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF) +cmake_dependent_option(USE_XCCL "Use XCCL" ON + "USE_XPU;UNIX;NOT APPLE" OFF) cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF) cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF) cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL" @@ -353,6 +355,8 @@ cmake_dependent_option(USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF) cmake_dependent_option(USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF) +cmake_dependent_option(USE_C10D_XCCL "USE C10D XCCL" ON + "USE_DISTRIBUTED;USE_XCCL" OFF) cmake_dependent_option(USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF) cmake_dependent_option( @@ -365,6 +369,8 @@ cmake_dependent_option( USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF) cmake_dependent_option( USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF) +cmake_dependent_option( + USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF) cmake_dependent_option( USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF) cmake_dependent_option( diff --git a/build_variables.bzl b/build_variables.bzl index e05c94bd83f57..98b721617b609 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -700,6 +700,10 @@ libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_s "torch/csrc/cuda/nccl.cpp", ] +libtorch_xpu_distributed_extra_sources = [ + "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp", +] + torch_cpp_srcs = [ "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA "torch/csrc/api/src/data/datasets/mnist.cpp", diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 8ed93cdff0479..d44a8da210462 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1014,6 +1014,9 @@ elseif(USE_CUDA) endif() if(USE_XPU) + if(USE_XCCL) + append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS) + endif() add_library(torch_xpu ${Caffe2_XPU_SRCS}) torch_compile_options(torch_xpu) # see cmake/public/utils.cmake target_compile_definitions(torch_xpu PRIVATE USE_XPU) @@ -1079,6 +1082,10 @@ if(USE_XPU) include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS}) endif() + if(USE_XCCL) + target_link_libraries(torch_xpu PRIVATE torch::xccl) + target_compile_definitions(torch_xpu PRIVATE USE_XCCL) + endif() endif() if(NOT MSVC AND USE_XNNPACK) @@ -1365,6 +1372,12 @@ if(USE_DISTRIBUTED) target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) endif() endif() + if(USE_C10D_XCCL) + target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) + set_source_files_properties( + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp + PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL") + endif() if(USE_MPI AND USE_C10D_MPI) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") set_source_files_properties( diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 2929f105b31fa..e5398a83cad94 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -45,6 +45,7 @@ {"USE_CUDNN", "${USE_CUDNN}"}, \ {"CUDNN_VERSION", "${CUDNN_VERSION}"}, \ {"USE_NCCL", "${USE_NCCL}"}, \ + {"USE_XCCL", "${USE_XCCL}"}, \ {"USE_MPI", "${USE_MPI}"}, \ {"USE_GFLAGS", "${USE_GFLAGS}"}, \ {"USE_GLOG", "${USE_GLOG}"}, \ diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ef33a3165340c..8abea841fcf61 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1150,6 +1150,22 @@ if(USE_CUDA) include_directories(SYSTEM ${CUB_INCLUDE_DIRS}) endif() +# ---[ XCCL +if(USE_XCCL) + if(NOT USE_XPU) + message(WARNING + "Not using XPU, so disabling USE_XCCL. Suppress this warning with " + "-DUSE_XCCL=OFF.") + caffe2_update_option(USE_XCCL OFF) + elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") + message(WARNING "USE_XCCL is currently only supported under Linux.") + caffe2_update_option(USE_XCCL OFF) + else() + include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake) + list(APPEND Caffe2_XPU_DEPENDENCY_LIBS torch::xccl) + endif() +endif() + if(USE_DISTRIBUTED AND USE_TENSORPIPE) if(MSVC) message(WARNING "Tensorpipe cannot be used on Windows.") diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake new file mode 100644 index 0000000000000..56205b381b132 --- /dev/null +++ b/cmake/External/xccl.cmake @@ -0,0 +1,17 @@ +if(NOT __XCCL_INCLUDED) + set(__XCCL_INCLUDED TRUE) + + if(USE_XCCL) + # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. + find_package(XCCL REQUIRED) + if(XCCL_FOUND) + add_library(torch::xccl INTERFACE IMPORTED) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${XCCL_INCLUDE_DIR}) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES + ${XCCL_LIBRARY}) + endif() + endif() +endif() diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake new file mode 100644 index 0000000000000..56b7fc0f7dcf3 --- /dev/null +++ b/cmake/Modules/FindXCCL.cmake @@ -0,0 +1,68 @@ +# This will define the following variables: +# XCCL_FOUND : True if the system has the XCCL library. +# XCCL_INCLUDE_DIR : Include directories needed to use XCCL. +# XCCL_LIBRARY_DIR :The path to the XCCL library. +# XCCL_LIBRARY : XCCL library fullname. + +include(FindPackageHandleStandardArgs) + +set(XCCL_ROOT "") +if(DEFINED ENV{CCL_ROOT}) + set(XCCL_ROOT $ENV{CCL_ROOT}) +endif() + +string(COMPARE EQUAL "${XCCL_ROOT}" "" nosyclfound) +if(nosyclfound) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "XCCL library not set!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +# Find include path from binary. +find_file( + XCCL_INCLUDE_DIR + NAMES include + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find include/oneapi path from include path. +find_file( + XCCL_INCLUDE_ONEAPI_DIR + NAMES oneapi + HINTS ${XCCL_ROOT}/include/ + NO_DEFAULT_PATH +) + +list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR}) + +# Find library directory from binary. +find_file( + XCCL_LIBRARY_DIR + NAMES lib + HINTS ${XCCL_ROOT} + NO_DEFAULT_PATH +) + +# Find XCCL library fullname. +find_library( + XCCL_LIBRARY + NAMES ccl + HINTS ${XCCL_LIBRARY_DIR} + NO_DEFAULT_PATH +) + +if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) + set(XCCL_FOUND False) + set(XCCL_REASON_FAILURE "XCCL library is incomplete!!") + set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") + return() +endif() + +find_package_handle_standard_args( + XCCL + FOUND_VAR XCCL_FOUND + REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY + REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}" +) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index d51c451589c2c..229ff112ab318 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -153,6 +153,12 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") endif() message(STATUS " USE_ITT : ${USE_ITT}") + message(STATUS " USE_XCCL : ${USE_XCCL}") + if(${USE_XCCL}) + message(STATUS " USE_C10D_XCCL : ${USE_C10D_XCCL}") + message(STATUS " XCCL include path : ${XCCL_INCLUDE_DIR}") + message(STATUS " XCCL library : ${XCCL_LIBRARY}") + endif() message(STATUS " USE_NCCL : ${USE_NCCL}") if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") diff --git a/setup.py b/setup.py index 92f1e2ddc7bcd..ad48f4b010863 100644 --- a/setup.py +++ b/setup.py @@ -645,6 +645,10 @@ def run(self): report("-- Building NCCL library") else: report("-- Not using NCCL") + if cmake_cache_vars["USE_XCCL"]: + report("-- Building XCCL library") + else: + report("-- Not using XCCL") if cmake_cache_vars["USE_DISTRIBUTED"]: if IS_WINDOWS: report("-- Building without distributed package") diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 6a0621f3f4991..3e5538d57e38a 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -66,8 +66,13 @@ def gpus_for_rank(world_size): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + device_count = ( + torch.xpu.device_count() + if torch.xpu.is_available() + else torch.cuda.device_count() + ) + visible_devices = list(range(device_count)) + gpus_per_process = device_count // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( diff --git a/test/distributed/test_c10d_xccl.py b/test/distributed/test_c10d_xccl.py new file mode 100644 index 0000000000000..704cdd414e554 --- /dev/null +++ b/test/distributed/test_c10d_xccl.py @@ -0,0 +1,303 @@ +# Owner(s): ["oncall: distributed"] + +import math +import os +import sys +import time +from datetime import timedelta +from unittest import mock + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +import test_c10d_common + +import torch.distributed as dist +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcessTestCase, + requires_xccl, +) +from torch.testing._internal.common_utils import ( + retry_on_connect_failures, + run_tests, + skip_but_pass_in_sandcastle_if, + TEST_XPU, + TestCase, +) + + +def simple_reduce_tests(rank, world_size): + tests = [ + ( + c10d.ReduceOp.SUM, + torch.tensor([rank + 1.0]), + torch.tensor([float(world_size * (world_size + 1) / 2)]), + ), + ( + c10d.ReduceOp.PRODUCT, + torch.tensor([rank + 1.0]), + torch.tensor([float(math.factorial(world_size))]), + ), + ( + c10d.ReduceOp.MIN, + torch.tensor([rank + 1.0]), + torch.tensor([1.0]), + ), + ( + c10d.ReduceOp.MAX, + torch.tensor([rank + 1.0]), + torch.tensor([world_size]), + ), + ] + + return tests + + +TEST_MULTIXPU = torch.xpu.device_count() > 1 + + +class RendezvousEnvTest(TestCase): + @retry_on_connect_failures + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_common_errors(self): + vars = { + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(common.find_free_port()), + } + + class Env: + def __init__(self, vars): + self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) + + def __enter__(self): + self.env_patcher.start() + + def __exit__(self, type, value, traceback): + self.env_patcher.stop() + + def without(d, key): + d = d.copy() + d.pop(key) + return d + + def withouts(d, keys): + d = d.copy() + for key in keys: + d.pop(key) + return d + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + with self.assertRaisesRegex(ValueError, "RANK expected"): + gen = c10d.rendezvous("env://") + next(gen) + c10d.init_process_group(backend="xccl", rank=0) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + c10d.init_process_group(backend="xccl", rank=0, world_size=1) + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(vars): + c10d.init_process_group(backend="xccl") + self.assertEqual(c10d.get_rank(), 0) + self.assertEqual(c10d.get_world_size(), 1) + c10d.destroy_process_group() + + with Env(without(vars, "MASTER_ADDR")): + self.assertEqual(None, os.environ.get("MASTER_ADDR")) + with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "MASTER_PORT")): + self.assertEqual(None, os.environ.get("MASTER_PORT")) + with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"): + gen = c10d.rendezvous("env://") + next(gen) + + with Env(without(vars, "WORLD_SIZE")): + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?world_size={1}") + _, _, size = next(gen) + self.assertEqual(size, 1) + + with Env(without(vars, "RANK")): + self.assertEqual(None, os.environ.get("RANK")) + gen = c10d.rendezvous(f"env://?rank={0}") + _, rank, _ = next(gen) + self.assertEqual(rank, 0) + + with Env(withouts(vars, ["RANK", "WORLD_SIZE"])): + self.assertEqual(None, os.environ.get("RANK")) + self.assertEqual(None, os.environ.get("WORLD_SIZE")) + gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}") + _, rank, size = next(gen) + self.assertEqual(rank, 0) + self.assertEqual(size, 1) + + +class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): + @requires_xccl() + @retry_on_connect_failures + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_default_store_timeout_nccl(self): + self._test_default_store_timeout("xccl") + + +class ProcessGroupXCCLTest(MultiProcessTestCase): + def _create_process_group_xccl( + self, timeout=timedelta(seconds=600), device_id=None + ): + store = c10d.FileStore(self.file_name, self.world_size) + c10d.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + timeout=timeout, + device_id=device_id, + ) + pg = c10d.distributed_c10d._get_default_group() + return pg + + def setUp(self): + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + return 2 + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. + pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + def test_set_process_group_desc(self): + device = torch.device(f"xpu:{self.rank}") + pg_default = self._create_process_group_xccl(device_id=device) + self.assertEqual(pg_default.group_desc, "default_pg") + pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + self.assertEqual(pg_1.group_desc, "test_purpose") + pg_2 = c10d.new_group([0, 1]) + self.assertEqual(pg_2.group_desc, "undefined") + + def _test_allreduce_basics(self, fn): + pg = self._create_process_group_xccl() + device = torch.device("xpu:" + str(self.rank)) + # Single input tests + tests = simple_reduce_tests(self.rank, self.world_size) + for op, input, expected in tests: + opts = c10d.AllreduceOptions() + opts.reduceOp = op + tensor = fn(input.to(device)) + fut = pg.allreduce([tensor], opts).get_future() + fut.wait() + result = fut.value() + self.assertEqual(expected, result[0], exact_dtype=False) + + x = fn(torch.tensor([self.rank + 1.0], device=device)) + fut = pg.allreduce(x).get_future() + fut.wait() + result = fut.value() + self.assertEqual( + torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), + result[0], + ) + + @requires_xccl() + def test_allreduce_basics(self): + self._test_allreduce_basics(lambda t: t.clone()) + + +if __name__ == "__main__": + assert ( + not torch.xpu._initialized + ), "test_distributed must not have initialized XPU context on main process" + + run_tests() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index bb949a081c95e..9a91b26d54cfb 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -282,6 +282,9 @@ if(USE_DISTRIBUTED) if(USE_NCCL) list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) endif() + if(USE_XCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::xccl) + endif() # Same for MPI. if(USE_MPI) list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) @@ -345,6 +348,10 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() + if(USE_XPU AND USE_C10D_XCCL) + target_compile_definitions(torch_python PRIVATE USE_C10D_XCCL) + endif() + if(USE_DISTRIBUTED) target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) endif() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 94e8578bbfff6..6033d96992597 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -309,6 +309,7 @@ class ProcessGroup: UNDEFINED = ... GLOO = ... NCCL = ... + XCCL = ... UCC = ... MPI = ... CUSTOM = ... @@ -697,3 +698,11 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class ProcessGroupXCCL(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + ): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ae822ad397504..699c54236f641 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -79,6 +79,7 @@ namespace { } IMPL_SEND(CPU) +IMPL_SEND(XPU) IMPL_SEND(CUDA) IMPL_SEND(PrivateUse1) @@ -94,6 +95,7 @@ IMPL_SEND(PrivateUse1) } IMPL_RECV(CPU) +IMPL_RECV(XPU) IMPL_RECV(CUDA) IMPL_RECV(PrivateUse1) @@ -108,6 +110,7 @@ IMPL_RECV(PrivateUse1) } IMPL_RECV_ANY_SOURCE(CPU) +IMPL_RECV_ANY_SOURCE(XPU) IMPL_RECV_ANY_SOURCE(CUDA) IMPL_RECV_ANY_SOURCE(PrivateUse1) @@ -131,6 +134,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) } IMPL_REDUCE(CPU) +IMPL_REDUCE(XPU) IMPL_REDUCE(CUDA) IMPL_REDUCE(PrivateUse1) @@ -156,6 +160,7 @@ IMPL_REDUCE(PrivateUse1) } IMPL_BROADCAST(CPU) +IMPL_BROADCAST(XPU) IMPL_BROADCAST(CUDA) IMPL_BROADCAST(PrivateUse1) @@ -181,6 +186,7 @@ IMPL_BROADCAST(PrivateUse1) IMPL_ALLREDUCE(CPU) IMPL_ALLREDUCE(CUDA) +IMPL_ALLREDUCE(XPU) IMPL_ALLREDUCE(PrivateUse1) #define IMPL_ALLREDUCE_COALESCED(DEV) \ @@ -198,6 +204,7 @@ IMPL_ALLREDUCE(PrivateUse1) } IMPL_ALLREDUCE_COALESCED(CPU) +IMPL_ALLREDUCE_COALESCED(XPU) IMPL_ALLREDUCE_COALESCED(CUDA) IMPL_ALLREDUCE_COALESCED(PrivateUse1) @@ -222,6 +229,7 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast) IMPL_ALLGATHER(CPU) +IMPL_ALLGATHER(XPU) IMPL_ALLGATHER(CUDA) IMPL_ALLGATHER(PrivateUse1) @@ -242,6 +250,7 @@ IMPL_ALLGATHER(PrivateUse1) } IMPL__ALLGATHER_BASE(CPU) +IMPL__ALLGATHER_BASE(XPU) IMPL__ALLGATHER_BASE(CUDA) IMPL__ALLGATHER_BASE(PrivateUse1) @@ -258,6 +267,7 @@ IMPL__ALLGATHER_BASE(PrivateUse1) } IMPL_ALLGATHER_COALESCED(CPU) +IMPL_ALLGATHER_COALESCED(XPU) IMPL_ALLGATHER_COALESCED(CUDA) IMPL_ALLGATHER_COALESCED(PrivateUse1) @@ -273,6 +283,7 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) +IMPL_ALLGATHER_INTO_TENSOR_COALESCED(XPU) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) @@ -296,6 +307,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) } IMPL_REDUCE_SCATTER(CPU) +IMPL_REDUCE_SCATTER(XPU) IMPL_REDUCE_SCATTER(CUDA) IMPL_REDUCE_SCATTER(PrivateUse1) @@ -320,6 +332,7 @@ IMPL_REDUCE_SCATTER(PrivateUse1) } IMPL__REDUCE_SCATTER_BASE(CPU) +IMPL__REDUCE_SCATTER_BASE(XPU) IMPL__REDUCE_SCATTER_BASE(CUDA) IMPL__REDUCE_SCATTER_BASE(PrivateUse1) @@ -341,6 +354,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) +IMPL_REDUCE_SCATTER_TENSOR_COALESCED(XPU) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) @@ -360,6 +374,7 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) } IMPL_GATHER(CPU) +IMPL_GATHER(XPU) IMPL_GATHER(CUDA) IMPL_GATHER(PrivateUse1) @@ -382,6 +397,7 @@ IMPL_GATHER(PrivateUse1) } IMPL_SCATTER(CPU) +IMPL_SCATTER(XPU) IMPL_SCATTER(CUDA) IMPL_SCATTER(PrivateUse1) @@ -403,6 +419,7 @@ IMPL_SCATTER(PrivateUse1) } IMPL_ALLTOALL(CPU) +IMPL_ALLTOALL(XPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) @@ -424,6 +441,7 @@ IMPL_ALLTOALL(PrivateUse1) } IMPL_ALLTOALL_BASE(CPU) +IMPL_ALLTOALL_BASE(XPU) IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) @@ -439,6 +457,7 @@ IMPL_ALLTOALL_BASE(PrivateUse1) } IMPL_BARRIER(CPU) +IMPL_BARRIER(XPU) IMPL_BARRIER(CUDA) IMPL_BARRIER(PrivateUse1) // NOLINTEND(cppcoreguidelines-pro-type-const-cast) @@ -491,6 +510,7 @@ namespace { #define REGISTER_C10D_OP(FUNC) \ REGISTER_C10D_OP1(FUNC, CPU) \ REGISTER_C10D_OP1(FUNC, CUDA) \ + REGISTER_C10D_OP1(FUNC, XPU) \ REGISTER_C10D_OP1(FUNC, PrivateUse1) // Now we start to register ops with the three device keys diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 75635bc68aed4..70356b3bf382c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -21,6 +21,8 @@ static ProcessGroup::BackendType strToBackendType(std::string_view backend) { return ProcessGroup::BackendType::GLOO; } else if (backend == "nccl") { return ProcessGroup::BackendType::NCCL; + } else if (backend == "xccl") { + return ProcessGroup::BackendType::XCCL; } else if (backend == "ucc") { return ProcessGroup::BackendType::UCC; } else if (backend == "mpi") { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index acf8c9c354a76..73fc2bda70132 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -70,6 +70,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { UCC = 3, MPI = 4, CUSTOM = 5, + XCCL = 6, }; // Not used, set for backwards compatibility and only used for TypeDef in @@ -489,6 +490,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // TODO: HACK for backend name to get sequence number for that backend. if (backendType == ProcessGroup::BackendType::GLOO || backendType == ProcessGroup::BackendType::NCCL || + backendType == ProcessGroup::BackendType::XCCL || backendType == ProcessGroup::BackendType::UCC) { getDefaultBackend()->setSequenceNumberForGroup(); } else { @@ -510,6 +512,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { // TODO: HACK for backend name to get sequence number for that backend. if (backendType == ProcessGroup::BackendType::GLOO || backendType == ProcessGroup::BackendType::NCCL || + backendType == ProcessGroup::BackendType::XCCL || backendType == ProcessGroup::BackendType::UCC) { return getDefaultBackend()->getSequenceNumberForGroup(); } else { diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp new file mode 100644 index 0000000000000..5aeeb62bee1ec --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -0,0 +1,401 @@ +#include +#include +#include +#include + +#ifdef USE_C10D_XCCL +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10d { + +namespace { +std::map xcclOps = { + {ReduceOp::MIN, ccl::reduction::min}, + {ReduceOp::MAX, ccl::reduction::max}, + {ReduceOp::SUM, ccl::reduction::sum}, + {ReduceOp::PRODUCT, ccl::reduction::prod}, +}; + +std::map xcclDatatypes = { + {at::kByte, ccl::datatype::uint8}, + {at::kChar, ccl::datatype::int8}, + {at::kInt, ccl::datatype::int32}, + {at::kLong, ccl::datatype::int64}, + {at::kHalf, ccl::datatype::float16}, + {at::kFloat, ccl::datatype::float32}, + {at::kDouble, ccl::datatype::float64}, + {at::kBFloat16, ccl::datatype::bfloat16}, + {at::kBool, ccl::datatype::uint8}, +}; + +XCCL_KVS kvs; +std::mutex kvs_mutex; + +XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + + return kvs; +} + +void check_xpu_single_tensor(const at::Tensor& tensor) { + if (!tensor.is_xpu() || tensor.is_sparse()) { + C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); + } + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } +} + +ccl::datatype getXcclDataType(at::ScalarType type) { + auto it = xcclDatatypes.find(type); + TORCH_CHECK_WITH( + TypeError, + it != xcclDatatypes.end(), + "Input tensor data type is not supported for XCCL process group: ", + type); + return it->second; +} + +ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { + try { + if (input.scalar_type() == at::kBool) { + if (reduceOp == ReduceOp::SUM) { + // For bool tensors, map sum to max, which both represent a bitwise or. + // This is to prevent overflow issues with sum, since we use uint8 to + // represent a bool (see xcclDatatypes mapping align with cuda). + return ccl::reduction::max; + } + } + return xcclOps.at(reduceOp); + } catch (const std::out_of_range&) { + switch (reduceOp) { + case ReduceOp::AVG: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL"); + break; + case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL"); + break; + case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL"); + break; + case ReduceOp::BXOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL"); + break; + default: + C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); + break; + } + } +} + +} // namespace + +static std::mutex xcclCommDevIdxMapMutex; +static std::unordered_map, int> xcclCommDevIdxMap; +constexpr int64_t kSynchronizeBusyWaitMillis = 10; + +ProcessGroupXCCL::WorkXCCL::WorkXCCL( + at::Device& device, + int rank, + OpType opType, + const std::optional>& inputs) + : Work(rank, opType, "profilingTitle", inputs), + device_(device), + workStartTime_(std::chrono::steady_clock::now()) { + unsigned char enable_timing = 0; + xcclEndEvent_ = std::make_shared(enable_timing); +} + +ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) + : Work(w.rank_, w.opType_), + device_(w.device_), + xcclEndEvent_(w.xcclEndEvent_), + blockingWait_(w.blockingWait_), + workStartTime_(w.workStartTime_) {} + +ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; + +bool ProcessGroupXCCL::WorkXCCL::checkTimeout( + std::optional timeout) { + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); + + auto workTimeout = timeout ? *timeout : opTimeout; + + if (timeElapsed < workTimeout) + return false; + return true; +} + +bool ProcessGroupXCCL::WorkXCCL::isCompleted() { + if (xcclEndEvent_ && xcclEndEvent_->query()) { + return true; + } + return false; +} + +void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeStream() { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + // Block the current stream on the XCCL stream + xcclEndEvent_->block(currentStream); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { + synchronizeStream(); + + if (blockingWait_) { + while (!isCompleted()) { + bool timedOut = checkTimeout( + timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + if (timedOut) { + break; + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } +} + +bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { + synchronizeInternal(timeout); + return true; +} + +ProcessGroupXCCL::ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size) + : Backend(rank, size), store_(store) { + blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); + init(); + + // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI + // launchers. + if (!with_mpirun()) { + int local_rank = getXCCLEnvVar("LOCAL_RANK"); + int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); + if (local_rank == -1 || local_world_size == -1) { + local_rank = rank; + local_world_size = size; + } + setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); + setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); + setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); + } +} + +ProcessGroupXCCL::~ProcessGroupXCCL() = default; + +c10::intrusive_ptr ProcessGroupXCCL::initWork( + at::Device& device, + int rank, + OpType opType, + const std::vector& inputs, + const std::vector& outputs) { + auto r = c10::make_intrusive( + device, rank, opType, std::optional>(inputs)); + return r; +} + +std::shared_ptr ProcessGroupXCCL::getXCCLComm( + const std::string& deviceKey, + at::Device& device) { + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the XCCL Communicator since " + "the devices are empty "); + } + + { + std::lock_guard lock(mutex_); + if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { + return devXCCLCommMap_[deviceKey]; + } + } + + std::shared_ptr XCCLComm; + + XCCL_KVS kvs = get_kvs(rank_, *store_); + + int numRanks, rank; + numRanks = getSize(); + rank = getRank(); + + c10::impl::VirtualGuardImpl impl(device.type()); + c10::Stream stream = impl.getStream(device); + sycl::queue& q = c10::xpu::XPUStream(stream).queue(); + + auto ctx = ccl::create_context(q.get_context()); + ccl::vector_class> devs_rank; + devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); + + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs); + XCCLComm = std::make_shared(std::move(comms[0])); + + { + std::lock_guard lock(mutex_); + inInitializationCommMap_.emplace(deviceKey, XCCLComm); + } + + xcclStreams_.emplace(deviceKey, std::move(stream)); + + auto it = inInitializationCommMap_.find(deviceKey); + if (it != inInitializationCommMap_.end()) { + devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); + inInitializationCommMap_.erase(deviceKey); + + xcclCommDevIdxMapMutex.lock(); + xcclCommDevIdxMap.emplace(XCCLComm, device.index()); + xcclCommDevIdxMapMutex.unlock(); + } + + it = devXCCLCommMap_.find(deviceKey); + TORCH_INTERNAL_ASSERT( + it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); + + return it->second; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType) { + using traits = function_traits; + using attr_t = typename traits::template arg<2>::type; + attr_t attr = ccl::create_operation_attr(); + + auto device = input.device(); + const auto key = std::to_string(device.index()); + auto comm = getXCCLComm(key, device); + + auto stream = xcclStreams_.at(key); + std::vector outputs{output}; + + c10::intrusive_ptr work; + + work = initWork(device, rank_, opType); + + work->outputs_ = + std::make_shared>(std::move(outputs)); + c10::xpu::XPUCachingAllocator::recordStream( + input.storage().data_ptr(), stream); + + auto ccl_stream = ccl::create_stream(stream.queue()); + + fn(input, output, attr, *comm, ccl_stream); + + work->xcclEndEvent_->record(stream); + + std::vector streams = {stream.unwrap()}; + c10::MultiStreamGuard streamGuard(streams); + std::vector devices{device}; + work->future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get()), devices); + work->future_->markCompleted(at::IValue(*work->outputs_)); + work->blockingWait_ = blockingWait_; + + return work; +} + +template +c10::intrusive_ptr ProcessGroupXCCL::collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + opType); +} + +c10::intrusive_ptr ProcessGroupXCCL::allreduce( + std::vector& tensors, + const AllreduceOptions& opts) { + TORCH_CHECK( + tensors.size() == 1, "Expecting one tensor only but got multiple"); + auto tensor = tensors.back(); + check_xpu_single_tensor(tensor); + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + ccl::allreduce_attr attr, + xcclComm_t& comm, + ccl::stream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::event ret_evt; + ret_evt = ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + stream, + attr); + return ret_evt; + }, + OpType::ALLREDUCE); +} + +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp new file mode 100644 index 0000000000000..14a9f398a8cbe --- /dev/null +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -0,0 +1,308 @@ +#pragma once + +#if defined(__linux__) +#include +#include +#include +#include +#endif + +#ifdef USE_C10D_XCCL +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +namespace c10d { + +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} +} // namespace + +static std::vector TORCH_XCCL_BLOCKING_WAIT = { + "TORCH_XCCL_BLOCKING_WAIT", + "XCCL_BLOCKING_WAIT"}; + +using xcclComm_t = ccl::communicator; +using XCCL_KVS = ccl::shared_ptr_class; +constexpr const char* XCCL_BACKEND_NAME = "xccl"; + +class TORCH_API ProcessGroupXCCL : public Backend { + public: + class WorkXCCL : public Work { + public: + WorkXCCL( + at::Device& device, + int rank, + OpType opType, + const std::optional>& inputs = std::nullopt); + WorkXCCL(const WorkXCCL& w); + ~WorkXCCL() override; + + bool isCompleted() override; + + bool isSuccess() const override { + TORCH_CHECK( + false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); + } + + void abort() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); + } + + void synchronize() override; + + void synchronizeStream(); + + bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; + + c10::intrusive_ptr getFuture() override { + return future_; + } + + std::vector result() override { + TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); + } + + bool checkTimeout( + std::optional timeout = std::nullopt); + + protected: + at::Device device_; + std::shared_ptr xcclEndEvent_; + bool blockingWait_ = false; + std::chrono::time_point workStartTime_; + + private: + void synchronizeInternal(std::chrono::milliseconds timeout); + std::shared_ptr> outputs_; + c10::intrusive_ptr future_; + friend class ProcessGroupXCCL; + }; + + ProcessGroupXCCL(const c10::intrusive_ptr& store, int rank, int size); + + C10_DEPRECATED ProcessGroupXCCL( + const c10::intrusive_ptr& store, + int rank, + int size, + const std::string& groupName) + : ProcessGroupXCCL(store, rank, size) {} + + ~ProcessGroupXCCL() override; + + const std::string getBackendName() const override { + return std::string(XCCL_BACKEND_NAME); + } + + std::shared_ptr getXCCLComm( + const std::string& deviceKey, + at::Device& device); + + virtual c10::intrusive_ptr initWork( + at::Device& device, + int rank, + OpType opType, + const std::vector& inputs = {}, + const std::vector& outputs = {}); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + OpType opType); + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType); + + c10::intrusive_ptr allreduce( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented"); + } + + c10::intrusive_ptr reduce( + std::vector& tensors, + const ReduceOptions& opts = ReduceOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::reduce not implemented"); + } + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); + } + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allgather not implemented"); + } + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::_allgather_base not implemented"); + } + + c10::intrusive_ptr allgather_coalesced( + std::vector>& outputTensorLists, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::allgather_coalesced not implemented"); + } + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override { + TORCH_CHECK( + false, + "ProcessGroupXCCL::allgather_into_tensor_coalesced not implemented"); + } + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented"); + } + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK( + false, "ProcessGroupXCCL::_reduce_scatter_base not implemented"); + } + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override { + TORCH_CHECK( + false, + "ProcessGroupXCCL::reduce_scatter_tensor_coalesced not implemented"); + } + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::barrier not implemented"); + } + + c10::intrusive_ptr alltoall_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + std::vector& outputSplitSizes, + std::vector& inputSplitSizes, + const AllToAllOptions& opts = AllToAllOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::alltoall_base not implemented"); + } + + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::alltoall not implemented"); + } + + c10::intrusive_ptr send( + std::vector& tensors, + int dstRank, + int tag) override { + TORCH_CHECK(false, "ProcessGroupXCCL::send not implemented"); + } + + c10::intrusive_ptr recv( + std::vector& tensors, + int srcRank, + int tag) override { + TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented"); + } + + c10::intrusive_ptr gather( + std::vector>& outputTensors, + std::vector& inputTensors, + const GatherOptions& opts = GatherOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::gather not implemented"); + } + + c10::intrusive_ptr scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ScatterOptions& opts = ScatterOptions()) override { + TORCH_CHECK(false, "ProcessGroupXCCL::scatter not implemented"); + } + + protected: + std::unordered_map xcclStreams_; + std::unordered_map> + inInitializationCommMap_; + std::unordered_map> devXCCLCommMap_; + c10::intrusive_ptr store_; + std::mutex mutex_; + bool blockingWait_ = false; +}; +} // namespace c10d + +#endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index c8f9dff37f06e..e3ed6d6bd4bcb 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -37,6 +37,10 @@ #include #endif +#ifdef USE_C10D_XCCL +#include +#endif + #include #include #include @@ -2232,6 +2236,7 @@ The hook must have the following signature: .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED) .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO) .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL) + .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL) .value("UCC", ::c10d::ProcessGroup::BackendType::UCC) .value("MPI", ::c10d::ProcessGroup::BackendType::MPI) .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM) @@ -2877,6 +2882,23 @@ Example:: py::call_guard()); #endif +#ifdef USE_C10D_XCCL + auto processGroupXCCL = + intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>( + module, "ProcessGroupXCCL", backend) + .def( + py::init([](const c10::intrusive_ptr<::c10d::Store>& store, + int rank, + int size) { + return c10::make_intrusive<::c10d::ProcessGroupXCCL>( + store, rank, size); + }), + py::arg("store"), + py::arg("rank"), + py::arg("size"), + py::call_guard()); +#endif + py::enum_<::c10d::OpType>(module, "OpType") .value("BROADCAST", ::c10d::OpType::BROADCAST) .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 45e096985143a..9fa3224873c9f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -87,6 +87,7 @@ "is_nccl_available", "is_torchelastic_launched", "is_ucc_available", + "is_xccl_available", "isend", "monitored_barrier", "new_group", @@ -130,6 +131,7 @@ _NCCL_AVAILABLE = True _GLOO_AVAILABLE = True _UCC_AVAILABLE = True +_XCCL_AVAILABLE = True _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -193,6 +195,14 @@ def _export_c_types() -> None: except ImportError: _UCC_AVAILABLE = False +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + logger = logging.getLogger(__name__) PG_WRAPPER_STORE_PREFIX = "pg_wrapper" @@ -222,7 +232,7 @@ class Backend(str): """ An enum-like class for backends. - Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. The values of this class are lowercase strings, e.g., ``"gloo"``. They can be accessed as attributes, e.g., ``Backend.NCCL``. @@ -242,21 +252,24 @@ class Backend(str): NCCL = "nccl" UCC = "ucc" MPI = "mpi" + XCCL = "xccl" _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) _plugins: Dict[str, _BackendPlugin] = {} - backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { "cpu": GLOO, "cuda": NCCL, + "xpu": XCCL, } backend_capability: Dict[str, List[str]] = { GLOO: ["cpu", "cuda"], NCCL: ["cuda"], + XCCL: ["xpu"], UCC: ["cpu", "cuda"], MPI: ["cpu", "cuda"], } @@ -265,6 +278,7 @@ class Backend(str): UNDEFINED: ProcessGroup.BackendType.UNDEFINED, GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, } @@ -1098,6 +1112,11 @@ def is_ucc_available() -> bool: return _UCC_AVAILABLE +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + def is_backend_available(backend: str) -> bool: """ Check backend availability. @@ -1350,6 +1369,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> backends.add(backend) # type: ignore[arg-type] elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): backends.add(backend) # type: ignore[arg-type] + if torch.device("xpu") in devices and is_xccl_available(): + backend = group._get_backend(torch.device("xpu")) + if isinstance(backend, ProcessGroupXCCL): + backends.add(backend) # type: ignore[arg-type] if len(backends) == 0: warnings.warn("Set timeout is now only supported for either nccl or gloo.") for backend in backends: @@ -1385,7 +1408,7 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on - build-time configurations, valid values include ``mpi``, ``gloo``, + build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``, ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` and ``nccl`` backend will be created, see notes below for how multiple backends are managed. This field can be given as a lowercase string @@ -1651,10 +1674,13 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + if device_id is not None and ( + device_id.index is None + or (device_id.type != "cuda" and device_id.type != "xpu") + ): raise ValueError( "init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu" + "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu" ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value @@ -1762,7 +1788,6 @@ def _new_process_group_helper( pg_options = ProcessGroupNCCL.Options() pg_options.is_high_priority_stream = False pg_options._timeout = timeout - if split_from: pg_options.split_from = split_from pg_options.split_color = _process_group_color(global_ranks_in_group) @@ -1781,6 +1806,17 @@ def _new_process_group_helper( backend_prefix_store, group_rank, group_size, timeout=timeout ) backend_type = ProcessGroup.BackendType.UCC + elif backend_str == Backend.XCCL: + if not is_xccl_available(): + raise RuntimeError("Distributed package doesn't have XCCL built in") + if pg_options is not None: + assert isinstance( + pg_options, ProcessGroupXCCL.Options + ), "Expected pg_options argument to be of type ProcessGroupXCCL.Options" + backend_class = ProcessGroupXCCL( + backend_prefix_store, group_rank, group_size + ) + backend_type = ProcessGroup.BackendType.XCCL else: assert ( backend_str.upper() in Backend._plugins diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index d59102232f7db..26bdcce610312 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -180,7 +180,8 @@ def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= x: + if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \ + (torch.xpu.is_available() and torch.xpu.device_count() >= x): return func(*args, **kwargs) sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) @@ -320,6 +321,12 @@ def requires_nccl(): "c10d was not compiled with the NCCL backend", ) +def requires_xccl(): + return skip_but_pass_in_sandcastle_if( + not c10d.is_xccl_available(), + "c10d was not compiled with the XCCL backend", + ) + def requires_ucc(): return skip_but_pass_in_sandcastle_if( not c10d.is_ucc_available(), @@ -463,7 +470,7 @@ def init_multigpu_helper(world_size: int, backend: str): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - nGPUs = torch.cuda.device_count() + nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count() visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's From a71d69a50684d8e6c6edd2ddc285f7589a44914d Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 20 Sep 2024 03:39:00 +0000 Subject: [PATCH 02/30] Align latest --- torch/distributed/distributed_c10d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 2d9357bbd15a4..4bbb1c4101123 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1819,10 +1819,10 @@ def _new_process_group_helper( elif backend_str == Backend.XCCL: if not is_xccl_available(): raise RuntimeError("Distributed package doesn't have XCCL built in") - if pg_options is not None: + if backend_options is not None: assert isinstance( - pg_options, ProcessGroupXCCL.Options - ), "Expected pg_options argument to be of type ProcessGroupXCCL.Options" + backend_options, ProcessGroupXCCL.Options + ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options" backend_class = ProcessGroupXCCL( backend_prefix_store, group_rank, group_size ) From af6f03c3a1144c1ad4440323051d88d164854fa1 Mon Sep 17 00:00:00 2001 From: hanchao Date: Tue, 24 Sep 2024 05:19:22 +0000 Subject: [PATCH 03/30] hidden env --- .../distributed/c10d/ProcessGroupXCCL.cpp | 30 ------ .../distributed/c10d/ProcessGroupXCCL.hpp | 98 ++++++++++++------- 2 files changed, 64 insertions(+), 64 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 5aeeb62bee1ec..8aca359405011 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -45,36 +45,6 @@ std::map xcclDatatypes = { {at::kBool, ccl::datatype::uint8}, }; -XCCL_KVS kvs; -std::mutex kvs_mutex; - -XCCL_KVS get_kvs(int rank, c10d::Store& store) { - std::lock_guard lock(kvs_mutex); - if (kvs) - return kvs; - std::string storeKey = "xccl_kvs"; - - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - - return kvs; -} - void check_xpu_single_tensor(const at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse()) { C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 14a9f398a8cbe..521d35bcf67fd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -31,40 +31,6 @@ #include namespace c10d { -namespace { -int getXCCLEnvVar(std::string envVarName) { - char* stringValue = std::getenv(envVarName.c_str()); - if (stringValue != nullptr) { - try { - int val = std::stoi(stringValue); - return val; - } catch (std::exception& e) { - TORCH_CHECK( - false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } else { - return -1; - } -} - -template -void setXCCLEnvVar(const std::string& envVarName, T val) { - if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); - } else if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), val.c_str(), 1); - } -} - -bool with_mpirun() { - return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || - getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) - ? true - : false; -} -} // namespace - static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -303,6 +269,70 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::mutex mutex_; bool blockingWait_ = false; }; + +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} + +XCCL_KVS kvs; +std::mutex kvs_mutex; +XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + + return kvs; +} + +} // namespace } // namespace c10d #endif // USE_C10D_XCCL From 88bea257af5931dc9e083f1cf7027e802d587f8d Mon Sep 17 00:00:00 2001 From: hanchao Date: Sun, 29 Sep 2024 09:05:17 +0000 Subject: [PATCH 04/30] refine findccl code --- cmake/Modules/FindXCCL.cmake | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 56b7fc0f7dcf3..a717ad1dafc65 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -11,10 +11,10 @@ if(DEFINED ENV{CCL_ROOT}) set(XCCL_ROOT $ENV{CCL_ROOT}) endif() -string(COMPARE EQUAL "${XCCL_ROOT}" "" nosyclfound) -if(nosyclfound) +string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) +if(nocclfound) set(XCCL_FOUND False) - set(XCCL_REASON_FAILURE "XCCL library not set!!") + set(XCCL_REASON_FAILURE "OneCCL library not found!!") set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") return() endif() @@ -55,7 +55,7 @@ find_library( if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY)) set(XCCL_FOUND False) - set(XCCL_REASON_FAILURE "XCCL library is incomplete!!") + set(XCCL_REASON_FAILURE "OneCCL library not found!!") set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}") return() endif() From f6ea93450c1b8bd10b709ee5b4076ac25c6413a2 Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 30 Sep 2024 04:41:54 +0000 Subject: [PATCH 05/30] Add comments for build xccl --- caffe2/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 2160399a3ea29..54ec7db0cad87 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1013,6 +1013,10 @@ elseif(USE_CUDA) endif() if(USE_XPU) + # if SYCL runtime and oneCCL runtime are both system installed + # then building flag USE_XPU=ON , USE_XCCL=ON and USE_C10D_XCCL=ON; + # XCCL backend will be build in libtorch_xpu; + # manually set `USE_XCCL=OFF` disable XCCL backend building. if(USE_XCCL) append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS) endif() From 1226e3b0f2397fdff2b78348b50c079a065cfc10 Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 30 Sep 2024 05:04:56 +0000 Subject: [PATCH 06/30] refine workxccl --- .../distributed/c10d/ProcessGroupXCCL.cpp | 53 +++++++++---------- .../distributed/c10d/ProcessGroupXCCL.hpp | 5 -- 2 files changed, 25 insertions(+), 33 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 8aca359405011..62b793b54a10a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -123,20 +123,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; -bool ProcessGroupXCCL::WorkXCCL::checkTimeout( - std::optional timeout) { - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - workStartTime_); - std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); - - auto workTimeout = timeout ? *timeout : opTimeout; - - if (timeElapsed < workTimeout) - return false; - return true; -} - bool ProcessGroupXCCL::WorkXCCL::isCompleted() { if (xcclEndEvent_ && xcclEndEvent_->query()) { return true; @@ -145,23 +131,17 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() { } void ProcessGroupXCCL::WorkXCCL::synchronize() { - synchronizeInternal(kNoTimeout); -} - -void ProcessGroupXCCL::WorkXCCL::synchronizeStream() { auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); - // Block the current stream on the XCCL stream xcclEndEvent_->block(currentStream); -} - -void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( - std::chrono::milliseconds timeout) { - synchronizeStream(); - if (blockingWait_) { while (!isCompleted()) { - bool timedOut = checkTimeout( - timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); + bool timedOut = true; + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); + if (timeElapsed < opTimeout) + timedOut = false; if (timedOut) { break; } @@ -172,7 +152,24 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( } bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { - synchronizeInternal(timeout); + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); + if (blockingWait_) { + while (!isCompleted()) { + bool timedOut = true; + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); + if (timeElapsed < timeout) + timedOut = false; + if (timedOut) { + break; + } + std::this_thread::sleep_for( + std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); + } + } return true; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 521d35bcf67fd..8160ca18e99ba 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -64,8 +64,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { void synchronize() override; - void synchronizeStream(); - bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; c10::intrusive_ptr getFuture() override { @@ -76,9 +74,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); } - bool checkTimeout( - std::optional timeout = std::nullopt); - protected: at::Device device_; std::shared_ptr xcclEndEvent_; From d62e0be003512c7f636dc0a11e95696fe982d849 Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 30 Sep 2024 06:12:01 +0000 Subject: [PATCH 07/30] refine timeout --- .../distributed/c10d/ProcessGroupXCCL.cpp | 36 +++++++------------ 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 62b793b54a10a..173d55dfdbcc8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -131,20 +131,27 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() { } void ProcessGroupXCCL::WorkXCCL::synchronize() { + synchronizeInternal(kNoTimeout); +} + +void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( + std::chrono::milliseconds timeout) { auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); xcclEndEvent_->block(currentStream); if (blockingWait_) { while (!isCompleted()) { - bool timedOut = true; auto currentTimepoint = std::chrono::steady_clock::now(); auto timeElapsed = std::chrono::duration_cast( currentTimepoint - workStartTime_); std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); - if (timeElapsed < opTimeout) - timedOut = false; - if (timedOut) { - break; + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + TORCH_CHECK(false, exceptionMsg) } + std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } @@ -152,24 +159,7 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() { } bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { - auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); - xcclEndEvent_->block(currentStream); - if (blockingWait_) { - while (!isCompleted()) { - bool timedOut = true; - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - workStartTime_); - std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); - if (timeElapsed < timeout) - timedOut = false; - if (timedOut) { - break; - } - std::this_thread::sleep_for( - std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); - } - } + synchronizeInternal(timeout); return true; } From 714de2a1af1090a2c102921a82cda727b8ad7928 Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 30 Sep 2024 06:33:38 +0000 Subject: [PATCH 08/30] rm head --- torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp | 16 +++------------- torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp | 1 - 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 173d55dfdbcc8..fb9b21f1f63e7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,11 +1,9 @@ +#ifdef USE_C10D_XCCL + #include #include -#include -#include - -#ifdef USE_C10D_XCCL -#include #include +#include #include #include #include @@ -13,15 +11,7 @@ #include #include -#include -#include -#include #include -#include -#include -#include -#include -#include namespace c10d { diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 8160ca18e99ba..232d0200ae3cf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -28,7 +28,6 @@ #include #include #include -#include namespace c10d { static std::vector TORCH_XCCL_BLOCKING_WAIT = { From 09237819b12218271aae1a5f7ec9ca47d3aeebed Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 30 Sep 2024 07:19:23 +0000 Subject: [PATCH 09/30] update --- .../distributed/c10d/ProcessGroupNCCL.cpp | 9 ++- .../distributed/c10d/ProcessGroupXCCL.cpp | 21 +++---- .../distributed/c10d/ProcessGroupXCCL.hpp | 57 +++++++++---------- torch/csrc/distributed/c10d/Utils.hpp | 25 ++++++++ 4 files changed, 66 insertions(+), 46 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 8a7aefdc238c4..ff47e2cf21301 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -180,13 +180,12 @@ ncclRedOpRAII getNcclReduceOp( NCCL_MINOR)); break; case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); - break; case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); - break; case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); + C10_THROW_ERROR( + ValueError, + "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + + " with NCCL"); break; default: C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index fb9b21f1f63e7..0dd022115cc46 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -67,17 +67,15 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { switch (reduceOp) { - case ReduceOp::AVG: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL"); - break; - case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL"); - break; - case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL"); - break; - case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL"); + case c10d::ReduceOp::BAND: + case c10d::ReduceOp::BOR: + case c10d::ReduceOp::BXOR: + case c10d::ReduceOp::AVG: + case c10d::ReduceOp::PREMUL_SUM: + C10_THROW_ERROR( + ValueError, + "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + + " with XCCL"); break; default: C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); @@ -133,7 +131,6 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( auto currentTimepoint = std::chrono::steady_clock::now(); auto timeElapsed = std::chrono::duration_cast( currentTimepoint - workStartTime_); - std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); if (timeElapsed >= timeout) { std::string exceptionMsg = c10::str( "Work ran for ", diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 232d0200ae3cf..99b815f2138b4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -262,6 +262,34 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; std::mutex mutex_; bool blockingWait_ = false; + + private: + XCCL_KVS kvs; + std::mutex kvs_mutex; + XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } }; namespace { @@ -297,35 +325,6 @@ bool with_mpirun() { : false; } -XCCL_KVS kvs; -std::mutex kvs_mutex; -XCCL_KVS get_kvs(int rank, c10d::Store& store) { - std::lock_guard lock(kvs_mutex); - if (kvs) - return kvs; - std::string storeKey = "xccl_kvs"; - - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - - return kvs; -} - } // namespace } // namespace c10d diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35f..73e37e0437c45 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets( return offset; } +inline std::string reduce_op_to_string(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} + using RankType = uint32_t; using SizeType = uint64_t; From 31d092d72303b08dcdfa0b2fd8b4e4ae45d3dffd Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 9 Oct 2024 08:18:35 +0000 Subject: [PATCH 10/30] minor fix --- caffe2/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 54ec7db0cad87..16a8834225915 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1085,7 +1085,7 @@ if(USE_XPU) include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS}) endif() - if(USE_XCCL) + if(USE_C10D_XCCL) target_link_libraries(torch_xpu PRIVATE torch::xccl) target_compile_definitions(torch_xpu PRIVATE USE_XCCL) endif() From cbea299190dcc0e90796fac38cb7b1adb2a34e1a Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 9 Oct 2024 10:31:25 +0000 Subject: [PATCH 11/30] rm duplicate code and refine cmake --- CMakeLists.txt | 2 -- caffe2/CMakeLists.txt | 4 ++-- cmake/Dependencies.cmake | 1 - cmake/External/xccl.cmake | 3 +++ 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0318fcb4d1ec0..60fc8aae14173 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -369,8 +369,6 @@ cmake_dependent_option( USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF) cmake_dependent_option( USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF) -cmake_dependent_option( - USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF) cmake_dependent_option( USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF) cmake_dependent_option( diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 16a8834225915..b4ec018019f16 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1085,7 +1085,7 @@ if(USE_XPU) include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS}) endif() - if(USE_C10D_XCCL) + if(USE_XCCL) target_link_libraries(torch_xpu PRIVATE torch::xccl) target_compile_definitions(torch_xpu PRIVATE USE_XCCL) endif() @@ -1374,7 +1374,7 @@ if(USE_DISTRIBUTED) target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL) endif() endif() - if(USE_C10D_XCCL) + if(USE_XPU AND USE_C10D_XCCL) target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3e59b813d3138..ee38f19773af8 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1163,7 +1163,6 @@ if(USE_XCCL) caffe2_update_option(USE_XCCL OFF) else() include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake) - list(APPEND Caffe2_XPU_DEPENDENCY_LIBS torch::xccl) endif() endif() diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake index 56205b381b132..467bb830e0b6c 100644 --- a/cmake/External/xccl.cmake +++ b/cmake/External/xccl.cmake @@ -12,6 +12,9 @@ if(NOT __XCCL_INCLUDED) set_property( TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES ${XCCL_LIBRARY}) + else() + set(USE_XCCL OFF) + set(USE_C10D_XCCL OFF) endif() endif() endif() From ef261c6f3de6d9cd25c12dbf149fca83482996e2 Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 10 Oct 2024 00:30:33 +0000 Subject: [PATCH 12/30] update cmake --- cmake/Dependencies.cmake | 3 +++ cmake/External/xccl.cmake | 25 ++++++++++--------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ee38f19773af8..f90846e89c754 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1163,6 +1163,9 @@ if(USE_XCCL) caffe2_update_option(USE_XCCL OFF) else() include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake) + if(NOT XCCL_FOUND) + caffe2_update_option(USE_XCCL OFF) + endif() endif() endif() diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake index 467bb830e0b6c..acb7cee87593e 100644 --- a/cmake/External/xccl.cmake +++ b/cmake/External/xccl.cmake @@ -1,20 +1,15 @@ if(NOT __XCCL_INCLUDED) set(__XCCL_INCLUDED TRUE) - if(USE_XCCL) - # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. - find_package(XCCL REQUIRED) - if(XCCL_FOUND) - add_library(torch::xccl INTERFACE IMPORTED) - set_property( - TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES - ${XCCL_INCLUDE_DIR}) - set_property( - TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES - ${XCCL_LIBRARY}) - else() - set(USE_XCCL OFF) - set(USE_C10D_XCCL OFF) - endif() + # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake. + find_package(XCCL REQUIRED) + if(XCCL_FOUND) + add_library(torch::xccl INTERFACE IMPORTED) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES + ${XCCL_INCLUDE_DIR}) + set_property( + TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES + ${XCCL_LIBRARY}) endif() endif() From 6c648cdbc1260f55256eebff0e0a0d6981b66694 Mon Sep 17 00:00:00 2001 From: hanchao Date: Tue, 24 Sep 2024 05:19:22 +0000 Subject: [PATCH 13/30] hidden xccl specific --- .../distributed/c10d/ProcessGroupXCCL.cpp | 104 +++--------------- .../distributed/c10d/ProcessGroupXCCL.hpp | 103 ++++++++++------- torch/csrc/distributed/c10d/Utils.hpp | 25 +++++ 3 files changed, 106 insertions(+), 126 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 5aeeb62bee1ec..d26d25ae03e39 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,11 +1,9 @@ +#ifdef USE_C10D_XCCL + #include #include -#include -#include - -#ifdef USE_C10D_XCCL -#include #include +#include #include #include #include @@ -13,15 +11,7 @@ #include #include -#include -#include -#include #include -#include -#include -#include -#include -#include namespace c10d { @@ -45,36 +35,6 @@ std::map xcclDatatypes = { {at::kBool, ccl::datatype::uint8}, }; -XCCL_KVS kvs; -std::mutex kvs_mutex; - -XCCL_KVS get_kvs(int rank, c10d::Store& store) { - std::lock_guard lock(kvs_mutex); - if (kvs) - return kvs; - std::string storeKey = "xccl_kvs"; - - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - - return kvs; -} - void check_xpu_single_tensor(const at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse()) { C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); @@ -106,23 +66,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { - switch (reduceOp) { - case ReduceOp::AVG: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL"); - break; - case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL"); - break; - case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL"); - break; - case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL"); - break; - default: - C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); - break; - } + C10_THROW_ERROR( + ValueError, + "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL"); } } @@ -153,20 +99,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; -bool ProcessGroupXCCL::WorkXCCL::checkTimeout( - std::optional timeout) { - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - workStartTime_); - std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); - - auto workTimeout = timeout ? *timeout : opTimeout; - - if (timeElapsed < workTimeout) - return false; - return true; -} - bool ProcessGroupXCCL::WorkXCCL::isCompleted() { if (xcclEndEvent_ && xcclEndEvent_->query()) { return true; @@ -178,23 +110,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() { synchronizeInternal(kNoTimeout); } -void ProcessGroupXCCL::WorkXCCL::synchronizeStream() { - auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); - // Block the current stream on the XCCL stream - xcclEndEvent_->block(currentStream); -} - void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( std::chrono::milliseconds timeout) { - synchronizeStream(); - + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); if (blockingWait_) { while (!isCompleted()) { - bool timedOut = checkTimeout( - timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); - if (timedOut) { - break; + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + TORCH_CHECK(false, exceptionMsg) } + std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 14a9f398a8cbe..99b815f2138b4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -28,43 +28,8 @@ #include #include #include -#include namespace c10d { -namespace { -int getXCCLEnvVar(std::string envVarName) { - char* stringValue = std::getenv(envVarName.c_str()); - if (stringValue != nullptr) { - try { - int val = std::stoi(stringValue); - return val; - } catch (std::exception& e) { - TORCH_CHECK( - false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } else { - return -1; - } -} - -template -void setXCCLEnvVar(const std::string& envVarName, T val) { - if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); - } else if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), val.c_str(), 1); - } -} - -bool with_mpirun() { - return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || - getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) - ? true - : false; -} -} // namespace - static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -98,8 +63,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { void synchronize() override; - void synchronizeStream(); - bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; c10::intrusive_ptr getFuture() override { @@ -110,9 +73,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); } - bool checkTimeout( - std::optional timeout = std::nullopt); - protected: at::Device device_; std::shared_ptr xcclEndEvent_; @@ -302,7 +262,70 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; std::mutex mutex_; bool blockingWait_ = false; + + private: + XCCL_KVS kvs; + std::mutex kvs_mutex; + XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } }; + +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} + +} // namespace } // namespace c10d #endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35f..73e37e0437c45 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets( return offset; } +inline std::string reduce_op_to_string(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} + using RankType = uint32_t; using SizeType = uint64_t; From e621fe6010382c3c4e614df4ace6a861f598442d Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 11 Oct 2024 01:55:18 +0000 Subject: [PATCH 14/30] fix ci fail --- test/distributed/test_c10d_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index d96abb1ca8267..903df26bba9f6 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1836,6 +1836,9 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + elif backend == dist.Backend.XCCL: + if not dist.is_xccl_available(): + continue # Multi-threaded PG is defined as a pure python class. # Its pg.name() does not going through Pybind, so its backend name # is still "threaded" instead of "custom". From 3f225d9b7eb38be2ca898dd90cfdf67c9543a0cb Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 00:36:44 +0000 Subject: [PATCH 15/30] rm vir fun and modify tensor check --- .../distributed/c10d/ProcessGroupXCCL.cpp | 30 +++-- .../distributed/c10d/ProcessGroupXCCL.hpp | 124 +----------------- 2 files changed, 19 insertions(+), 135 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 0dd022115cc46..d9058b532ec74 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -36,15 +36,21 @@ std::map xcclDatatypes = { }; void check_xpu_single_tensor(const at::Tensor& tensor) { - if (!tensor.is_xpu() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); - } - if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR( + ValueError, "Tensors must be XPU and dense and non-complex"); + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } } } -ccl::datatype getXcclDataType(at::ScalarType type) { +ccl::datatype getXcclDataType( + at::ScalarType type, + bool is_reduction_op = false) { + TORCH_CHECK( + !isFloat8Type(type) && is_reduction_op, + "Float8 dtypes are not currenlty supported for XCCL reductions"); auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( TypeError, @@ -58,9 +64,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { if (input.scalar_type() == at::kBool) { if (reduceOp == ReduceOp::SUM) { - // For bool tensors, map sum to max, which both represent a bitwise or. - // This is to prevent overflow issues with sum, since we use uint8 to - // represent a bool (see xcclDatatypes mapping align with cuda). + // For bool tensors, map sum to max, which both represent a bitwise + // or. This is to prevent overflow issues with sum, since we use uint8 + // to represent a bool (see xcclDatatypes mapping align with cuda). return ccl::reduction::max; } } @@ -227,7 +233,7 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( inInitializationCommMap_.emplace(deviceKey, XCCLComm); } - xcclStreams_.emplace(deviceKey, std::move(stream)); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); auto it = inInitializationCommMap_.find(deviceKey); if (it != inInitializationCommMap_.end()) { @@ -262,7 +268,7 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); - auto stream = xcclStreams_.at(key); + auto stream = xcclStreamsMap_.at(key); std::vector outputs{output}; c10::intrusive_ptr work; @@ -323,7 +329,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( ccl::allreduce_attr attr, xcclComm_t& comm, ccl::stream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); ccl::event ret_evt; ret_evt = ccl::allreduce( diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 99b815f2138b4..1c4b22cbc8ded 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -132,130 +132,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts = - AllreduceCoalescedOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented"); - } - - c10::intrusive_ptr reduce( - std::vector& tensors, - const ReduceOptions& opts = ReduceOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::reduce not implemented"); - } - - c10::intrusive_ptr broadcast( - std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); - } - - c10::intrusive_ptr allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allgather not implemented"); - } - - c10::intrusive_ptr _allgather_base( - at::Tensor& outputbuffer, - at::Tensor& inputbuffer, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::_allgather_base not implemented"); - } - - c10::intrusive_ptr allgather_coalesced( - std::vector>& outputTensorLists, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allgather_coalesced not implemented"); - } - - c10::intrusive_ptr allgather_into_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::allgather_into_tensor_coalesced not implemented"); - } - - c10::intrusive_ptr reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented"); - } - - c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, "ProcessGroupXCCL::_reduce_scatter_base not implemented"); - } - - c10::intrusive_ptr reduce_scatter_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::reduce_scatter_tensor_coalesced not implemented"); - } - - c10::intrusive_ptr barrier( - const BarrierOptions& opts = BarrierOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::barrier not implemented"); - } - - c10::intrusive_ptr alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& opts = AllToAllOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::alltoall_base not implemented"); - } - - c10::intrusive_ptr alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& opts = AllToAllOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::alltoall not implemented"); - } - - c10::intrusive_ptr send( - std::vector& tensors, - int dstRank, - int tag) override { - TORCH_CHECK(false, "ProcessGroupXCCL::send not implemented"); - } - - c10::intrusive_ptr recv( - std::vector& tensors, - int srcRank, - int tag) override { - TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented"); - } - - c10::intrusive_ptr gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts = GatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::gather not implemented"); - } - - c10::intrusive_ptr scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts = ScatterOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::scatter not implemented"); - } - protected: - std::unordered_map xcclStreams_; + std::unordered_map xcclStreamsMap_; std::unordered_map> inInitializationCommMap_; std::unordered_map> devXCCLCommMap_; From 8e5e78a1cfdad2f767e8ab53e9b5e2a80fb53c7e Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 01:24:06 +0000 Subject: [PATCH 16/30] refine collective, getcomm --- .../distributed/c10d/ProcessGroupXCCL.cpp | 69 ++++++------------- .../distributed/c10d/ProcessGroupXCCL.hpp | 29 ++++++-- 2 files changed, 45 insertions(+), 53 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 5b3bc82cd30b8..b25275757dea2 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,5 +1,6 @@ #ifdef USE_C10D_XCCL +#include #include #include #include @@ -218,32 +219,18 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( { std::lock_guard lock(mutex_); - inInitializationCommMap_.emplace(deviceKey, XCCLComm); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); } xcclStreamsMap_.emplace(deviceKey, std::move(stream)); - auto it = inInitializationCommMap_.find(deviceKey); - if (it != inInitializationCommMap_.end()) { - devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); - inInitializationCommMap_.erase(deviceKey); - - xcclCommDevIdxMapMutex.lock(); - xcclCommDevIdxMap.emplace(XCCLComm, device.index()); - xcclCommDevIdxMapMutex.unlock(); - } - - it = devXCCLCommMap_.find(deviceKey); - TORCH_INTERNAL_ASSERT( - it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); - - return it->second; + return XCCLComm; } template c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, @@ -252,25 +239,26 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( using attr_t = typename traits::template arg<2>::type; attr_t attr = ccl::create_operation_attr(); - auto device = input.device(); + auto device = inputs[0].device(); const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); auto stream = xcclStreamsMap_.at(key); - std::vector outputs{output}; c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - work->outputs_ = - std::make_shared>(std::move(outputs)); - c10::xpu::XPUCachingAllocator::recordStream( - input.storage().data_ptr(), stream); + work->outputs_ = std::make_shared>(outputs); - auto ccl_stream = ccl::create_stream(stream.queue()); + at::xpu::OptionalXPUGuard gpuGuard(device); - fn(input, output, attr, *comm, ccl_stream); + pre(stream, work); + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], attr, *comm, stream); + } + post(stream, work); work->xcclEndEvent_->record(stream); @@ -285,23 +273,6 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - OpType opType) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - opType); -} - c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { @@ -316,20 +287,20 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( at::Tensor& output, ccl::allreduce_attr attr, xcclComm_t& comm, - ccl::stream& stream) { + at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + auto ccl_stream = ccl::create_stream(stream.queue()); + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, comm, - stream, + ccl_stream, attr); - return ret_evt; + return; }, OpType::ALLREDUCE); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 1c4b22cbc8ded..d813428484c41 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -70,7 +70,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { } std::vector result() override { - TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); + return *outputs_; } protected: @@ -117,7 +117,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& input, at::Tensor& output, Fn fn, - OpType opType); + OpType opType) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType); + } template c10::intrusive_ptr collective( @@ -126,6 +136,19 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, PreProcess pre, PostProcess post, + OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + PreProcess pre, + PostProcess post, OpType opType); c10::intrusive_ptr allreduce( @@ -134,8 +157,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: std::unordered_map xcclStreamsMap_; - std::unordered_map> - inInitializationCommMap_; std::unordered_map> devXCCLCommMap_; c10::intrusive_ptr store_; std::mutex mutex_; From 1267963e0f4c5674fee452ae1ee0a5bbd96fb053 Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 02:23:30 +0000 Subject: [PATCH 17/30] accept comments --- torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp | 12 ++++++++++-- torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp | 6 +----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index b25275757dea2..7287a47990e8c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -79,10 +79,16 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { } } +void syncStream( + at::Device& device, + at::xpu::XPUEvent& xcclEvent, + at::xpu::XPUStream& xcclStream) { + xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); + xcclEvent.block(xcclStream); +} + } // namespace -static std::mutex xcclCommDevIdxMapMutex; -static std::unordered_map, int> xcclCommDevIdxMap; constexpr int64_t kSynchronizeBusyWaitMillis = 10; ProcessGroupXCCL::WorkXCCL::WorkXCCL( @@ -223,6 +229,7 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); return XCCLComm; } @@ -244,6 +251,7 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( auto comm = getXCCLComm(key, device); auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); c10::intrusive_ptr work; work = initWork(device, rank_, opType); diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index d813428484c41..41b78035c83f9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -52,11 +52,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { bool isCompleted() override; - bool isSuccess() const override { - TORCH_CHECK( - false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); - } - void abort() override { TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); } @@ -157,6 +152,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: std::unordered_map xcclStreamsMap_; + std::unordered_map xcclEventsMap_; std::unordered_map> devXCCLCommMap_; c10::intrusive_ptr store_; std::mutex mutex_; From 3d55b856939f63d82753e23f365fd06392ff97f0 Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 02:36:21 +0000 Subject: [PATCH 18/30] rm attr --- torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 7287a47990e8c..caf9aeee5a6bd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include @@ -242,10 +241,6 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( PreProcess pre, PostProcess post, OpType opType) { - using traits = function_traits; - using attr_t = typename traits::template arg<2>::type; - attr_t attr = ccl::create_operation_attr(); - auto device = inputs[0].device(); const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); @@ -264,7 +259,7 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( for (const auto i : c10::irange(inputs.size())) { c10::xpu::XPUCachingAllocator::recordStream( inputs[i].storage().data_ptr(), stream); - fn(inputs[i], outputs[i], attr, *comm, stream); + fn(inputs[i], outputs[i], *comm, stream); } post(stream, work); @@ -293,7 +288,6 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( tensor, [&](at::Tensor& input, at::Tensor& output, - ccl::allreduce_attr attr, xcclComm_t& comm, at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); @@ -306,8 +300,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( xcclDataType, xcclReduceOp, comm, - ccl_stream, - attr); + ccl_stream); return; }, OpType::ALLREDUCE); From f69059a627089face0bf7a2b64d7400d83686b9f Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 06:37:15 +0000 Subject: [PATCH 19/30] add default ccl root dir --- cmake/Modules/FindXCCL.cmake | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index a717ad1dafc65..c0c2e1f71c09a 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,8 +6,14 @@ include(FindPackageHandleStandardArgs) +# include FindMKL.cmake to use its variables +include(FindMKL) + set(XCCL_ROOT "") -if(DEFINED ENV{CCL_ROOT}) +if (EXISTS "${DEFAULT_INTEL_COMPILER_DIR}/ccl/latest") + SET(XCCL_ROOT "${DEFAULT_INTEL_COMPILER_DIR}/ccl/latest") +elseif(DEFINED ENV{CCL_ROOT}) + message(STATUS "Default OneCCL not found, using current environment OneCCL") set(XCCL_ROOT $ENV{CCL_ROOT}) endif() From bed720ce361a5522c7e133c200bb750e4a130613 Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 06:40:58 +0000 Subject: [PATCH 20/30] update --- cmake/Modules/FindXCCL.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index c0c2e1f71c09a..424d889ce9760 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -10,8 +10,8 @@ include(FindPackageHandleStandardArgs) include(FindMKL) set(XCCL_ROOT "") -if (EXISTS "${DEFAULT_INTEL_COMPILER_DIR}/ccl/latest") - SET(XCCL_ROOT "${DEFAULT_INTEL_COMPILER_DIR}/ccl/latest") +if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") + SET(XCCL_ROOT "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") elseif(DEFINED ENV{CCL_ROOT}) message(STATUS "Default OneCCL not found, using current environment OneCCL") set(XCCL_ROOT $ENV{CCL_ROOT}) From fd44abeffcd8e4523d97459d6c45a5f3d813c4cc Mon Sep 17 00:00:00 2001 From: hanchao Date: Sat, 12 Oct 2024 06:52:47 +0000 Subject: [PATCH 21/30] update --- cmake/Modules/FindXCCL.cmake | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index 424d889ce9760..c1ac5fff78630 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,13 +6,8 @@ include(FindPackageHandleStandardArgs) -# include FindMKL.cmake to use its variables -include(FindMKL) - -set(XCCL_ROOT "") -if (EXISTS "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") - SET(XCCL_ROOT "${DEFAULT_INTEL_ONEAPI_DIR}/ccl/latest") -elseif(DEFINED ENV{CCL_ROOT}) +set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") +if (NOT EXISTS "${XCCL_ROOT}") message(STATUS "Default OneCCL not found, using current environment OneCCL") set(XCCL_ROOT $ENV{CCL_ROOT}) endif() From d12b922d75ab6e92e5f711a98f8c05a31c88e46f Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Sun, 13 Oct 2024 12:47:15 +0800 Subject: [PATCH 22/30] code refine --- .../distributed/c10d/ProcessGroupNCCL.cpp | 9 +-- .../distributed/c10d/ProcessGroupXCCL.cpp | 72 +++++-------------- .../distributed/c10d/ProcessGroupXCCL.hpp | 61 +++------------- torch/csrc/distributed/c10d/Utils.hpp | 2 +- 4 files changed, 33 insertions(+), 111 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ff47e2cf21301..8a7aefdc238c4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -180,12 +180,13 @@ ncclRedOpRAII getNcclReduceOp( NCCL_MINOR)); break; case ReduceOp::BAND: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with NCCL"); + break; case ReduceOp::BOR: + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with NCCL"); + break; case ReduceOp::BXOR: - C10_THROW_ERROR( - ValueError, - "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + - " with NCCL"); + C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); break; default: C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index caf9aeee5a6bd..291a27a53559d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -16,14 +16,14 @@ namespace c10d { namespace { -std::map xcclOps = { +const std::map xcclOps = { {ReduceOp::MIN, ccl::reduction::min}, {ReduceOp::MAX, ccl::reduction::max}, {ReduceOp::SUM, ccl::reduction::sum}, {ReduceOp::PRODUCT, ccl::reduction::prod}, }; -std::map xcclDatatypes = { +const std::map xcclDatatypes = { {at::kByte, ccl::datatype::uint8}, {at::kChar, ccl::datatype::int8}, {at::kInt, ccl::datatype::int32}, @@ -35,7 +35,7 @@ std::map xcclDatatypes = { {at::kBool, ccl::datatype::uint8}, }; -void check_xpu_single_tensor(const at::Tensor& tensor) { +void checkXPUTensor(at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { C10_THROW_ERROR( ValueError, "Tensors must be XPU and dense and non-complex"); @@ -62,19 +62,15 @@ ccl::datatype getXcclDataType( ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { - if (input.scalar_type() == at::kBool) { - if (reduceOp == ReduceOp::SUM) { - // For bool tensors, map sum to max, which both represent a bitwise - // or. This is to prevent overflow issues with sum, since we use uint8 - // to represent a bool (see xcclDatatypes mapping align with cuda). - return ccl::reduction::max; - } + if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { + // Map sum to max for bool tensors to avoid overflow issues with sum. + return ccl::reduction::max; } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { C10_THROW_ERROR( ValueError, - "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL"); + "Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL"); } } @@ -85,7 +81,6 @@ void syncStream( xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); xcclEvent.block(xcclStream); } - } // namespace constexpr int64_t kSynchronizeBusyWaitMillis = 10; @@ -98,8 +93,7 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL( : Work(rank, opType, "profilingTitle", inputs), device_(device), workStartTime_(std::chrono::steady_clock::now()) { - unsigned char enable_timing = 0; - xcclEndEvent_ = std::make_shared(enable_timing); + xcclEndEvent_ = std::make_shared(); } ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) @@ -133,12 +127,11 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( currentTimepoint - workStartTime_); if (timeElapsed >= timeout) { std::string exceptionMsg = c10::str( - "Work ran for ", + "Work ran time out after ", timeElapsed.count(), - " milliseconds before timing out."); + " milliseconds."); TORCH_CHECK(false, exceptionMsg) } - std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } @@ -157,20 +150,6 @@ ProcessGroupXCCL::ProcessGroupXCCL( : Backend(rank, size), store_(store) { blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); init(); - - // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI - // launchers. - if (!with_mpirun()) { - int local_rank = getXCCLEnvVar("LOCAL_RANK"); - int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); - if (local_rank == -1 || local_world_size == -1) { - local_rank = rank; - local_world_size = size; - } - setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); - setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); - setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); - } } ProcessGroupXCCL::~ProcessGroupXCCL() = default; @@ -189,24 +168,16 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, at::Device& device) { - if (deviceKey.empty()) { - C10_THROW_ERROR( - DistBackendError, - "Not able to create/get the XCCL Communicator since " - "the devices are empty "); - } - + TORCH_CHECK_WITH(DistBackendError, deviceKey.empty(), "Not able to create/get " + "XCCL Communicator since the devices are empty "); { + // todo: why do we need mutex here? std::lock_guard lock(mutex_); if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { return devXCCLCommMap_[deviceKey]; } } - std::shared_ptr XCCLComm; - - XCCL_KVS kvs = get_kvs(rank_, *store_); - int numRanks, rank; numRanks = getSize(); rank = getRank(); @@ -219,14 +190,12 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( ccl::vector_class> devs_rank; devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs); - XCCLComm = std::make_shared(std::move(comms[0])); - - { - std::lock_guard lock(mutex_); - devXCCLCommMap_.emplace(deviceKey, XCCLComm); - } + auto xccl_kvs = get_kvs(rank_, *store_); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); + std::shared_ptr XCCLComm = std::make_shared(std::move(comms[0])); + std::lock_guard lock(mutex_); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); xcclStreamsMap_.emplace(deviceKey, std::move(stream)); xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); @@ -250,11 +219,9 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( c10::intrusive_ptr work; work = initWork(device, rank_, opType); - work->outputs_ = std::make_shared>(outputs); at::xpu::OptionalXPUGuard gpuGuard(device); - pre(stream, work); for (const auto i : c10::irange(inputs.size())) { c10::xpu::XPUCachingAllocator::recordStream( @@ -264,7 +231,6 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( post(stream, work); work->xcclEndEvent_->record(stream); - std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); std::vector devices{device}; @@ -282,7 +248,7 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( TORCH_CHECK( tensors.size() == 1, "Expecting one tensor only but got multiple"); auto tensor = tensors.back(); - check_xpu_single_tensor(tensor); + checkXPUTensor(tensor); return collective( tensor, tensor, diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 41b78035c83f9..155082fe6d734 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -35,7 +35,6 @@ static std::vector TORCH_XCCL_BLOCKING_WAIT = { "XCCL_BLOCKING_WAIT"}; using xcclComm_t = ccl::communicator; -using XCCL_KVS = ccl::shared_ptr_class; constexpr const char* XCCL_BACKEND_NAME = "xccl"; class TORCH_API ProcessGroupXCCL : public Backend { @@ -113,9 +112,11 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& output, Fn fn, OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; return collective( - input, - output, + inputs, + outputs, fn, [](at::xpu::XPUStream&, c10::intrusive_ptr&) {}, @@ -124,19 +125,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { opType); } - template - c10::intrusive_ptr collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - PreProcess pre, - PostProcess post, - OpType opType) { - auto inputs = std::vector{input}; - auto outputs = std::vector{output}; - return collective(inputs, outputs, fn, pre, post, opType); - } - template c10::intrusive_ptr collective( std::vector& inputs, @@ -159,9 +147,11 @@ class TORCH_API ProcessGroupXCCL : public Backend { bool blockingWait_ = false; private: - XCCL_KVS kvs; std::mutex kvs_mutex; - XCCL_KVS get_kvs(int rank, c10d::Store& store) { + ccl::shared_ptr_class kvs; + + ccl::shared_ptr_class get_kvs(int rank, c10d::Store& store) { + // todo: why do we need the mutex here? std::lock_guard lock(kvs_mutex); if (kvs) return kvs; @@ -186,41 +176,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { return kvs; } }; - -namespace { -int getXCCLEnvVar(std::string envVarName) { - char* stringValue = std::getenv(envVarName.c_str()); - if (stringValue != nullptr) { - try { - int val = std::stoi(stringValue); - return val; - } catch (std::exception& e) { - TORCH_CHECK( - false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } else { - return -1; - } -} - -template -void setXCCLEnvVar(const std::string& envVarName, T val) { - if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); - } else if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), val.c_str(), 1); - } -} - -bool with_mpirun() { - return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || - getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) - ? true - : false; -} - -} // namespace } // namespace c10d #endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 73e37e0437c45..e27ec363ba1cc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,7 +557,7 @@ size_t computeLengthsAndOffsets( return offset; } -inline std::string reduce_op_to_string(c10d::ReduceOp op) { +inline std::string reduceOpToString(c10d::ReduceOp op) { switch (op) { case c10d::ReduceOp::SUM: return "SUM"; From b57e8128390c53e2d58892f69962f13d961119cf Mon Sep 17 00:00:00 2001 From: hanchao Date: Mon, 14 Oct 2024 02:38:30 +0000 Subject: [PATCH 23/30] minor fix --- torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 291a27a53559d..e2d48b1982f25 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -63,7 +63,7 @@ ccl::datatype getXcclDataType( ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { - // Map sum to max for bool tensors to avoid overflow issues with sum. + // Map sum to max for bool tensors to avoid overflow issues with sum. return ccl::reduction::max; } return xcclOps.at(reduceOp); @@ -127,9 +127,7 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( currentTimepoint - workStartTime_); if (timeElapsed >= timeout) { std::string exceptionMsg = c10::str( - "Work ran time out after ", - timeElapsed.count(), - " milliseconds."); + "Work ran time out after ", timeElapsed.count(), " milliseconds."); TORCH_CHECK(false, exceptionMsg) } std::this_thread::sleep_for( @@ -168,7 +166,10 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, at::Device& device) { - TORCH_CHECK_WITH(DistBackendError, deviceKey.empty(), "Not able to create/get " + TORCH_CHECK_WITH( + DistBackendError, + !deviceKey.empty(), + "Not able to create/get " "XCCL Communicator since the devices are empty "); { // todo: why do we need mutex here? @@ -192,7 +193,8 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( auto xccl_kvs = get_kvs(rank_, *store_); auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); - std::shared_ptr XCCLComm = std::make_shared(std::move(comms[0])); + std::shared_ptr XCCLComm = + std::make_shared(std::move(comms[0])); std::lock_guard lock(mutex_); devXCCLCommMap_.emplace(deviceKey, XCCLComm); From 5968f0fca484c82d5fced80038f184a65957e84c Mon Sep 17 00:00:00 2001 From: hanchao Date: Tue, 15 Oct 2024 02:01:52 +0000 Subject: [PATCH 24/30] update --- cmake/Modules/FindXCCL.cmake | 4 ++-- torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index c1ac5fff78630..18f7ac642d54e 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -8,8 +8,8 @@ include(FindPackageHandleStandardArgs) set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") if (NOT EXISTS "${XCCL_ROOT}") - message(STATUS "Default OneCCL not found, using current environment OneCCL") - set(XCCL_ROOT $ENV{CCL_ROOT}) + message(STATUS "Default OneCCL not found, using current environment OneAPI") + set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) endif() string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index e2d48b1982f25..28c41c2ba0e03 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -184,7 +184,8 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( rank = getRank(); c10::impl::VirtualGuardImpl impl(device.type()); - c10::Stream stream = impl.getStream(device); + c10::Stream stream = + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); auto ctx = ccl::create_context(q.get_context()); From edba8aad74e104437b307b9bf09e808eaa06c46b Mon Sep 17 00:00:00 2001 From: hanchao Date: Wed, 16 Oct 2024 07:15:20 +0000 Subject: [PATCH 25/30] update --- .../distributed/c10d/ProcessGroupXCCL.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 28c41c2ba0e03..a57ec69975371 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,6 +1,7 @@ #ifdef USE_C10D_XCCL #include +#include #include #include #include @@ -252,6 +253,25 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( tensors.size() == 1, "Expecting one tensor only but got multiple"); auto tensor = tensors.back(); checkXPUTensor(tensor); + + RECORD_PARAM_COMMS_DATA( + // static_cast( + // this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + 1, + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + 0, // globalRankStart + 1, // globalRankStride + this->getSize()); // worldSize + return collective( tensor, tensor, From 56a5e7ff6ca99025855e8da554967f6362287ba5 Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 17 Oct 2024 00:25:04 +0000 Subject: [PATCH 26/30] Refine specific code --- cmake/Modules/FindXCCL.cmake | 7 +- .../distributed/c10d/ProcessGroupXCCL.cpp | 209 ++++++++---------- .../distributed/c10d/ProcessGroupXCCL.hpp | 194 ++-------------- torch/csrc/distributed/c10d/Utils.hpp | 2 +- 4 files changed, 115 insertions(+), 297 deletions(-) diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake index a717ad1dafc65..18f7ac642d54e 100644 --- a/cmake/Modules/FindXCCL.cmake +++ b/cmake/Modules/FindXCCL.cmake @@ -6,9 +6,10 @@ include(FindPackageHandleStandardArgs) -set(XCCL_ROOT "") -if(DEFINED ENV{CCL_ROOT}) - set(XCCL_ROOT $ENV{CCL_ROOT}) +set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest") +if (NOT EXISTS "${XCCL_ROOT}") + message(STATUS "Default OneCCL not found, using current environment OneAPI") + set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest) endif() string(COMPARE EQUAL "${XCCL_ROOT}" "" nocclfound) diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index d26d25ae03e39..ef007825a118e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,5 +1,6 @@ #ifdef USE_C10D_XCCL +#include #include #include #include @@ -9,21 +10,20 @@ #include #include -#include #include #include namespace c10d { namespace { -std::map xcclOps = { +const std::map xcclOps = { {ReduceOp::MIN, ccl::reduction::min}, {ReduceOp::MAX, ccl::reduction::max}, {ReduceOp::SUM, ccl::reduction::sum}, {ReduceOp::PRODUCT, ccl::reduction::prod}, }; -std::map xcclDatatypes = { +const std::map xcclDatatypes = { {at::kByte, ccl::datatype::uint8}, {at::kChar, ccl::datatype::int8}, {at::kInt, ccl::datatype::int32}, @@ -35,16 +35,22 @@ std::map xcclDatatypes = { {at::kBool, ccl::datatype::uint8}, }; -void check_xpu_single_tensor(const at::Tensor& tensor) { - if (!tensor.is_xpu() || tensor.is_sparse()) { - C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); - } - if (!tensor.is_contiguous(tensor.suggest_memory_format())) { - C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); +void checkXPUTensor(at::Tensor& tensor) { + if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) { + C10_THROW_ERROR( + ValueError, "Tensors must be XPU and dense and non-complex"); + if (!tensor.is_contiguous(tensor.suggest_memory_format())) { + C10_THROW_ERROR(ValueError, "Tensors must be contiguous"); + } } } -ccl::datatype getXcclDataType(at::ScalarType type) { +ccl::datatype getXcclDataType( + at::ScalarType type, + bool is_reduction_op = false) { + TORCH_CHECK( + !isFloat8Type(type) && is_reduction_op, + "Float8 dtypes are not currenlty supported for XCCL reductions"); auto it = xcclDatatypes.find(type); TORCH_CHECK_WITH( TypeError, @@ -56,26 +62,27 @@ ccl::datatype getXcclDataType(at::ScalarType type) { ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { try { - if (input.scalar_type() == at::kBool) { - if (reduceOp == ReduceOp::SUM) { - // For bool tensors, map sum to max, which both represent a bitwise or. - // This is to prevent overflow issues with sum, since we use uint8 to - // represent a bool (see xcclDatatypes mapping align with cuda). - return ccl::reduction::max; - } + if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) { + // Map sum to max for bool tensors to avoid overflow issues with sum. + return ccl::reduction::max; } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { C10_THROW_ERROR( ValueError, - "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL"); + "Cannot use ReduceOp." + reduceOpToString(reduceOp) + " with XCCL"); } } +void syncStream( + at::Device& device, + at::xpu::XPUEvent& xcclEvent, + at::xpu::XPUStream& xcclStream) { + xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); + xcclEvent.block(xcclStream); +} } // namespace -static std::mutex xcclCommDevIdxMapMutex; -static std::unordered_map, int> xcclCommDevIdxMap; constexpr int64_t kSynchronizeBusyWaitMillis = 10; ProcessGroupXCCL::WorkXCCL::WorkXCCL( @@ -86,8 +93,7 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL( : Work(rank, opType, "profilingTitle", inputs), device_(device), workStartTime_(std::chrono::steady_clock::now()) { - unsigned char enable_timing = 0; - xcclEndEvent_ = std::make_shared(enable_timing); + xcclEndEvent_ = std::make_shared(); } ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) @@ -121,12 +127,9 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( currentTimepoint - workStartTime_); if (timeElapsed >= timeout) { std::string exceptionMsg = c10::str( - "Work ran for ", - timeElapsed.count(), - " milliseconds before timing out."); + "Work ran time out after ", timeElapsed.count(), " milliseconds."); TORCH_CHECK(false, exceptionMsg) } - std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } @@ -145,20 +148,6 @@ ProcessGroupXCCL::ProcessGroupXCCL( : Backend(rank, size), store_(store) { blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); init(); - - // Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI - // launchers. - if (!with_mpirun()) { - int local_rank = getXCCLEnvVar("LOCAL_RANK"); - int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); - if (local_rank == -1 || local_world_size == -1) { - local_rank = rank; - local_world_size = size; - } - setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); - setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); - setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); - } } ProcessGroupXCCL::~ProcessGroupXCCL() = default; @@ -177,97 +166,74 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, at::Device& device) { - if (deviceKey.empty()) { - C10_THROW_ERROR( - DistBackendError, - "Not able to create/get the XCCL Communicator since " - "the devices are empty "); - } - + TORCH_CHECK_WITH( + DistBackendError, + !deviceKey.empty(), + "Not able to create/get " + "XCCL Communicator since the devices are empty "); { + // todo: why do we need mutex here? std::lock_guard lock(mutex_); if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { return devXCCLCommMap_[deviceKey]; } } - std::shared_ptr XCCLComm; - - XCCL_KVS kvs = get_kvs(rank_, *store_); - int numRanks, rank; numRanks = getSize(); rank = getRank(); c10::impl::VirtualGuardImpl impl(device.type()); - c10::Stream stream = impl.getStream(device); + c10::Stream stream = + impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false); sycl::queue& q = c10::xpu::XPUStream(stream).queue(); auto ctx = ccl::create_context(q.get_context()); ccl::vector_class> devs_rank; devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, kvs); - XCCLComm = std::make_shared(std::move(comms[0])); + auto xccl_kvs = get_kvs(rank_, *store_); + auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); + std::shared_ptr XCCLComm = + std::make_shared(std::move(comms[0])); - { - std::lock_guard lock(mutex_); - inInitializationCommMap_.emplace(deviceKey, XCCLComm); - } + std::lock_guard lock(mutex_); + devXCCLCommMap_.emplace(deviceKey, XCCLComm); + xcclStreamsMap_.emplace(deviceKey, std::move(stream)); + xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent()); - xcclStreams_.emplace(deviceKey, std::move(stream)); - - auto it = inInitializationCommMap_.find(deviceKey); - if (it != inInitializationCommMap_.end()) { - devXCCLCommMap_.emplace(deviceKey, std::move(it->second)); - inInitializationCommMap_.erase(deviceKey); - - xcclCommDevIdxMapMutex.lock(); - xcclCommDevIdxMap.emplace(XCCLComm, device.index()); - xcclCommDevIdxMapMutex.unlock(); - } - - it = devXCCLCommMap_.find(deviceKey); - TORCH_INTERNAL_ASSERT( - it != devXCCLCommMap_.end(), "Communicators not populated in cache!"); - - return it->second; + return XCCLComm; } template c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, OpType opType) { - using traits = function_traits; - using attr_t = typename traits::template arg<2>::type; - attr_t attr = ccl::create_operation_attr(); - - auto device = input.device(); + auto device = inputs[0].device(); const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); - auto stream = xcclStreams_.at(key); - std::vector outputs{output}; + auto stream = xcclStreamsMap_.at(key); + syncStream(device, xcclEventsMap_[key], stream); c10::intrusive_ptr work; - work = initWork(device, rank_, opType); - - work->outputs_ = - std::make_shared>(std::move(outputs)); - c10::xpu::XPUCachingAllocator::recordStream( - input.storage().data_ptr(), stream); - - auto ccl_stream = ccl::create_stream(stream.queue()); - - fn(input, output, attr, *comm, ccl_stream); + work->outputs_ = std::make_shared>(outputs); + + at::xpu::OptionalXPUGuard gpuGuard(device); + pre(stream, work); + for (const auto i : c10::irange(inputs.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputs[i].storage().data_ptr(), stream); + fn(inputs[i], outputs[i], *comm, stream); + } + post(stream, work); work->xcclEndEvent_->record(stream); - std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); std::vector devices{device}; @@ -279,51 +245,52 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } -template -c10::intrusive_ptr ProcessGroupXCCL::collective( - at::Tensor& input, - at::Tensor& output, - Fn fn, - OpType opType) { - return collective( - input, - output, - fn, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - [](at::xpu::XPUStream&, - c10::intrusive_ptr& work) {}, - opType); -} - c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { TORCH_CHECK( tensors.size() == 1, "Expecting one tensor only but got multiple"); auto tensor = tensors.back(); - check_xpu_single_tensor(tensor); + checkXPUTensor(tensor); + + RECORD_PARAM_COMMS_DATA( + // static_cast( + // this->getSequenceNumberForGroup() + 1), // seq + 1 to match + // collective + 1, + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + 0, // globalRankStart + 1, // globalRankStride + this->getSize()); // worldSize + return collective( tensor, tensor, [&](at::Tensor& input, at::Tensor& output, - ccl::allreduce_attr attr, xcclComm_t& comm, - ccl::stream& stream) { - auto xcclDataType = getXcclDataType(input.scalar_type()); + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - ccl::event ret_evt; - ret_evt = ccl::allreduce( + auto ccl_stream = ccl::create_stream(stream.queue()); + ccl::allreduce( input.data_ptr(), output.data_ptr(), (size_t)input.numel(), xcclDataType, xcclReduceOp, comm, - stream, - attr); - return ret_evt; + ccl_stream); + return; }, OpType::ALLREDUCE); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 99b815f2138b4..5dc003e3dba6b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -35,7 +35,6 @@ static std::vector TORCH_XCCL_BLOCKING_WAIT = { "XCCL_BLOCKING_WAIT"}; using xcclComm_t = ccl::communicator; -using XCCL_KVS = ccl::shared_ptr_class; constexpr const char* XCCL_BACKEND_NAME = "xccl"; class TORCH_API ProcessGroupXCCL : public Backend { @@ -52,11 +51,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { bool isCompleted() override; - bool isSuccess() const override { - TORCH_CHECK( - false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); - } - void abort() override { TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented"); } @@ -70,7 +64,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { } std::vector result() override { - TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); + return *outputs_; } protected: @@ -117,12 +111,24 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& input, at::Tensor& output, Fn fn, - OpType opType); + OpType opType) { + auto inputs = std::vector{input}; + auto outputs = std::vector{output}; + return collective( + inputs, + outputs, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType); + } template c10::intrusive_ptr collective( - at::Tensor& input, - at::Tensor& output, + std::vector& inputs, + std::vector& outputs, Fn fn, PreProcess pre, PostProcess post, @@ -132,141 +138,20 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - c10::intrusive_ptr allreduce_coalesced( - std::vector& tensors, - const AllreduceCoalescedOptions& opts = - AllreduceCoalescedOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allreduce_coalesced not implemented"); - } - - c10::intrusive_ptr reduce( - std::vector& tensors, - const ReduceOptions& opts = ReduceOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::reduce not implemented"); - } - - c10::intrusive_ptr broadcast( - std::vector& tensors, - const BroadcastOptions& opts = BroadcastOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::broadcast not implemented"); - } - - c10::intrusive_ptr allgather( - std::vector>& outputTensors, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allgather not implemented"); - } - - c10::intrusive_ptr _allgather_base( - at::Tensor& outputbuffer, - at::Tensor& inputbuffer, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::_allgather_base not implemented"); - } - - c10::intrusive_ptr allgather_coalesced( - std::vector>& outputTensorLists, - std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::allgather_coalesced not implemented"); - } - - c10::intrusive_ptr allgather_into_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const AllgatherOptions& opts = AllgatherOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::allgather_into_tensor_coalesced not implemented"); - } - - c10::intrusive_ptr reduce_scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::reduce_scatter not implemented"); - } - - c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, "ProcessGroupXCCL::_reduce_scatter_base not implemented"); - } - - c10::intrusive_ptr reduce_scatter_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override { - TORCH_CHECK( - false, - "ProcessGroupXCCL::reduce_scatter_tensor_coalesced not implemented"); - } - - c10::intrusive_ptr barrier( - const BarrierOptions& opts = BarrierOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::barrier not implemented"); - } - - c10::intrusive_ptr alltoall_base( - at::Tensor& outputTensor, - at::Tensor& inputTensor, - std::vector& outputSplitSizes, - std::vector& inputSplitSizes, - const AllToAllOptions& opts = AllToAllOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::alltoall_base not implemented"); - } - - c10::intrusive_ptr alltoall( - std::vector& outputTensors, - std::vector& inputTensors, - const AllToAllOptions& opts = AllToAllOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::alltoall not implemented"); - } - - c10::intrusive_ptr send( - std::vector& tensors, - int dstRank, - int tag) override { - TORCH_CHECK(false, "ProcessGroupXCCL::send not implemented"); - } - - c10::intrusive_ptr recv( - std::vector& tensors, - int srcRank, - int tag) override { - TORCH_CHECK(false, "ProcessGroupXCCL::recv not implemented"); - } - - c10::intrusive_ptr gather( - std::vector>& outputTensors, - std::vector& inputTensors, - const GatherOptions& opts = GatherOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::gather not implemented"); - } - - c10::intrusive_ptr scatter( - std::vector& outputTensors, - std::vector>& inputTensors, - const ScatterOptions& opts = ScatterOptions()) override { - TORCH_CHECK(false, "ProcessGroupXCCL::scatter not implemented"); - } - protected: - std::unordered_map xcclStreams_; - std::unordered_map> - inInitializationCommMap_; + std::unordered_map xcclStreamsMap_; + std::unordered_map xcclEventsMap_; std::unordered_map> devXCCLCommMap_; c10::intrusive_ptr store_; std::mutex mutex_; bool blockingWait_ = false; private: - XCCL_KVS kvs; std::mutex kvs_mutex; - XCCL_KVS get_kvs(int rank, c10d::Store& store) { + ccl::shared_ptr_class kvs; + + ccl::shared_ptr_class get_kvs(int rank, c10d::Store& store) { + // todo: why do we need the mutex here? std::lock_guard lock(kvs_mutex); if (kvs) return kvs; @@ -291,41 +176,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { return kvs; } }; - -namespace { -int getXCCLEnvVar(std::string envVarName) { - char* stringValue = std::getenv(envVarName.c_str()); - if (stringValue != nullptr) { - try { - int val = std::stoi(stringValue); - return val; - } catch (std::exception& e) { - TORCH_CHECK( - false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } else { - return -1; - } -} - -template -void setXCCLEnvVar(const std::string& envVarName, T val) { - if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); - } else if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), val.c_str(), 1); - } -} - -bool with_mpirun() { - return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || - getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) - ? true - : false; -} - -} // namespace } // namespace c10d #endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index 73e37e0437c45..e27ec363ba1cc 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,7 +557,7 @@ size_t computeLengthsAndOffsets( return offset; } -inline std::string reduce_op_to_string(c10d::ReduceOp op) { +inline std::string reduceOpToString(c10d::ReduceOp op) { switch (op) { case c10d::ReduceOp::SUM: return "SUM"; From a062f9f8bfc2cb81c4f082515f324b7c5f65dbb8 Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 17 Oct 2024 00:59:38 +0000 Subject: [PATCH 27/30] accept comments --- torch/csrc/distributed/c10d/ProcessGroup.hpp | 4 ++-- .../distributed/c10d/ProcessGroupXCCL.cpp | 19 ------------------- .../distributed/c10d/ProcessGroupXCCL.hpp | 3 ++- torch/distributed/distributed_c10d.py | 8 ++------ 4 files changed, 6 insertions(+), 28 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 83d2729fc43d4..31c974a061e4a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -51,8 +51,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { NCCL = 2, UCC = 3, MPI = 4, - CUSTOM = 5, - XCCL = 6, + XCCL = 5, + CUSTOM = 6, }; static std::string backendTypeToString(const BackendType& type) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index ef007825a118e..90fb4c3f9cbd7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -253,25 +253,6 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( auto tensor = tensors.back(); checkXPUTensor(tensor); - RECORD_PARAM_COMMS_DATA( - // static_cast( - // this->getSequenceNumberForGroup() + 1), // seq + 1 to match - // collective - 1, - std::make_tuple(pg_uid_, pg_desc_), // PG name tuple - tensors, // inputTensors - tensors, // outputTensors - rank_, // rank - "allreduce", // collective name - tensor.numel(), // inNelems - tensor.numel(), // outNelems - tensor.scalar_type(), // dType - std::vector(), // inSplitSizes - std::vector(), // outSplitSizes - 0, // globalRankStart - 1, // globalRankStride - this->getSize()); // worldSize - return collective( tensor, tensor, diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 5dc003e3dba6b..6e6eb16d62d62 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -138,6 +138,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; + void setSequenceNumberForGroup() override {} + protected: std::unordered_map xcclStreamsMap_; std::unordered_map xcclEventsMap_; @@ -151,7 +153,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { ccl::shared_ptr_class kvs; ccl::shared_ptr_class get_kvs(int rank, c10d::Store& store) { - // todo: why do we need the mutex here? std::lock_guard lock(kvs_mutex); if (kvs) return kvs; diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 4bbb1c4101123..fc4ca55dbd023 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1675,13 +1675,9 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and ( - device_id.index is None - or (device_id.type != "cuda" and device_id.type != "xpu") - ): + if device_id is not None and device_id.index is None: raise ValueError( - "init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu" + "init_process_group device_id parameter must be a device with an index" ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value From ab04fc02de99bdd48a8eaf4aeea27c3cd12ae274 Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 17 Oct 2024 08:17:49 +0000 Subject: [PATCH 28/30] rm header and refine profilehead --- caffe2/CMakeLists.txt | 3 - .../distributed/c10d/ProcessGroupXCCL.cpp | 58 +++++++++++++------ .../distributed/c10d/ProcessGroupXCCL.hpp | 43 ++++++++------ 3 files changed, 65 insertions(+), 39 deletions(-) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b4ec018019f16..25bd7f700f68a 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1376,9 +1376,6 @@ if(USE_DISTRIBUTED) endif() if(USE_XPU AND USE_C10D_XCCL) target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL) - set_source_files_properties( - ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupXCCL.cpp - PROPERTIES COMPILE_DEFINITIONS "CCL_ENABLE_ZE;CCL_ENABLE_SYCL") endif() if(USE_MPI AND USE_C10D_MPI) if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 90fb4c3f9cbd7..76d265ca5de28 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,17 +1,8 @@ #ifdef USE_C10D_XCCL #include +#include #include -#include -#include -#include -#include -#include -#include -#include - -#include -#include namespace c10d { @@ -89,10 +80,13 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL( at::Device& device, int rank, OpType opType, + uint64_t seq, + const char* profilingTitle, const std::optional>& inputs) - : Work(rank, opType, "profilingTitle", inputs), + : Work(rank, opType, profilingTitle, inputs), device_(device), - workStartTime_(std::chrono::steady_clock::now()) { + workStartTime_(std::chrono::steady_clock::now()), + seq_(seq) { xcclEndEvent_ = std::make_shared(); } @@ -101,7 +95,8 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) device_(w.device_), xcclEndEvent_(w.xcclEndEvent_), blockingWait_(w.blockingWait_), - workStartTime_(w.workStartTime_) {} + workStartTime_(w.workStartTime_), + seq_(w.seq_) {} ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; @@ -156,10 +151,16 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( at::Device& device, int rank, OpType opType, + const char* profilingTitle, const std::vector& inputs, const std::vector& outputs) { auto r = c10::make_intrusive( - device, rank, opType, std::optional>(inputs)); + device, + rank, + opType, + seqCollective_, + profilingTitle, + std::optional>(inputs)); return r; } @@ -212,7 +213,10 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( Fn fn, PreProcess pre, PostProcess post, - OpType opType) { + OpType opType, + const char* profilingTitle) { + seqCollective_++; + auto device = inputs[0].device(); const auto key = std::to_string(device.index()); auto comm = getXCCLComm(key, device); @@ -221,7 +225,7 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( syncStream(device, xcclEventsMap_[key], stream); c10::intrusive_ptr work; - work = initWork(device, rank_, opType); + work = initWork(device, rank_, opType, profilingTitle); work->outputs_ = std::make_shared>(outputs); at::xpu::OptionalXPUGuard gpuGuard(device); @@ -253,6 +257,25 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( auto tensor = tensors.back(); checkXPUTensor(tensor); + RECORD_PARAM_COMMS_DATA( + // static_cast( + // this->getSequenceNumberForGroup() + 1), // seq + 1 to match + // collective + 1, + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + 0, // globalRankStart + 1, // globalRankStride + this->getSize()); // worldSize + return collective( tensor, tensor, @@ -273,7 +296,8 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( ccl_stream); return; }, - OpType::ALLREDUCE); + OpType::ALLREDUCE, + "xccl:all_reduce"); } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 6e6eb16d62d62..f9761c652dc1a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -1,33 +1,24 @@ #pragma once -#if defined(__linux__) -#include -#include -#include -#include -#endif - #ifdef USE_C10D_XCCL -#include +// We will define those flags in XCCL backend file instead of passing to gcc +// compiler. +#define CCL_ENABLE_ZE +#define CCL_ENABLE_SYCL + #include -#include #include -#include -#include - -#include -#include #include -#include #include #include -#include #include +#include +#include #include #include #include -#include +#include namespace c10d { static std::vector TORCH_XCCL_BLOCKING_WAIT = { @@ -45,6 +36,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Device& device, int rank, OpType opType, + uint64_t seq, + const char* profilingTitle = nullptr, const std::optional>& inputs = std::nullopt); WorkXCCL(const WorkXCCL& w); ~WorkXCCL() override; @@ -63,6 +56,10 @@ class TORCH_API ProcessGroupXCCL : public Backend { return future_; } + uint64_t getSequencenumber() const override { + return seq_; + } + std::vector result() override { return *outputs_; } @@ -72,6 +69,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { std::shared_ptr xcclEndEvent_; bool blockingWait_ = false; std::chrono::time_point workStartTime_; + uint64_t seq_; private: void synchronizeInternal(std::chrono::milliseconds timeout); @@ -103,6 +101,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Device& device, int rank, OpType opType, + const char* profilingTitle = nullptr, const std::vector& inputs = {}, const std::vector& outputs = {}); @@ -111,7 +110,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { at::Tensor& input, at::Tensor& output, Fn fn, - OpType opType) { + OpType opType, + const char* profilingTitle = nullptr) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; return collective( @@ -132,13 +132,17 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, PreProcess pre, PostProcess post, - OpType opType); + OpType opType, + const char* profilingTitle = nullptr); c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; void setSequenceNumberForGroup() override {} + uint64_t getSequenceNumberForGroup() override { + return seqCollective_; + } protected: std::unordered_map xcclStreamsMap_; @@ -147,6 +151,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; std::mutex mutex_; bool blockingWait_ = false; + uint64_t seqCollective_{0}; private: std::mutex kvs_mutex; From 4ee49fba9336c7a6e4d2dbb23012f76a1094c97b Mon Sep 17 00:00:00 2001 From: hanchao Date: Thu, 17 Oct 2024 09:51:06 +0000 Subject: [PATCH 29/30] add get_device_count --- test/distributed/test_c10d_common.py | 13 +++++-------- torch/testing/_internal/common_distributed.py | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 903df26bba9f6..d3cb65f7befb1 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -31,6 +31,7 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, skip_if_lt_x_gpu, + get_device_count, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -60,17 +61,13 @@ torch.backends.cuda.matmul.allow_tf32 = False -def gpus_for_rank(world_size): +def gpus_for_rank(world_size, backend): """Multigpu tests are designed to simulate the multi nodes with multi GPUs on each node. Nccl backend requires equal #GPUs in each process. On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - device_count = ( - torch.xpu.device_count() - if torch.xpu.is_available() - else torch.cuda.device_count() - ) + device_count = get_device_count(backend) visible_devices = list(range(device_count)) gpus_per_process = device_count // world_size gpus_for_rank = [] @@ -833,7 +830,7 @@ def update_parameters(model): def _gpu_model_with_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False, state=None ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -850,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook( def _gpu_model_with_builtin_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 9ec38c9ca671c..3e1664690b713 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -93,8 +93,9 @@ class DistTestCases: # Sets showing that something is implemented backend_feature = {} - backend_feature["gpu"] = {"nccl", "gloo", "ucc"} + backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"} backend_feature["cuda"] = {"nccl", "gloo", "ucc"} + backend_feature["cuda"] = {"xccl"} backend_feature["ddp"] = {"nccl", "gloo", "ucc"} backend_feature["subgroup"] = {"nccl", "gloo", "ucc"} backend_feature["plugin"] = set() @@ -462,6 +463,15 @@ def compute_sum(fn, world_size: int): ] ] +# Returns the number of GPUs, currently only for CUDA and XPU. +def get_device_count(backend: str): + assert c10d.is_backend_available(backend) + if backend in backend_feature.get("cuda", set()): + return torch.cuda.device_count() + elif backend in backend_feature.get("xpu", set()): + return torch.xpu.device_count() + else: + raise ValueError(f"Unsupported backend: {backend}") # HELPER FOR MULTIGPU TESTS def init_multigpu_helper(world_size: int, backend: str): @@ -470,7 +480,7 @@ def init_multigpu_helper(world_size: int, backend: str): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count() + nGPUs = get_device_count(backend) visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's From 1a2c9c23f81090f7b9855fcb1fbac838bee46158 Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 18 Oct 2024 00:44:48 +0000 Subject: [PATCH 30/30] add backendSupportsSequenceNumbers --- torch/csrc/distributed/c10d/ProcessGroup.hpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 31c974a061e4a..b3eac70e871bf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -131,6 +131,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendType_; }; + inline bool backendSupportsSequenceNumbers(BackendType backendType) { + if (backendType == BackendType::GLOO || backendType == BackendType::NCCL || + backendType == BackendType::XCCL || backendType == BackendType::UCC) + return true; + return false; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -508,10 +515,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { virtual void setSequenceNumberForGroup() { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::XCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { getDefaultBackend()->setSequenceNumberForGroup(); } else { TORCH_CHECK( @@ -530,10 +534,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::XCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { return getDefaultBackend()->getSequenceNumberForGroup(); } else { TORCH_CHECK(