register xccl

Chao1Han committed Nov 28, 2024
1 parent 62eea62 commit 5f5ff60
Showing 3 changed files with 52 additions and 0 deletions.
caffe2/CMakeLists.txt: 4 additions & 0 deletions
@@ -1104,6 +1104,10 @@ if(USE_XPU)
     message(WARNING "Failed to include ATen XPU implementation target")
   else()
     target_link_libraries(torch_xpu PRIVATE torch_xpu_ops)
+    if(USE_C10D_XCCL)
+      message(WARNING "USE_C10D_XCCL")
+      target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+    endif()
     if(MSVC)
       # Windows
       target_link_options(torch_xpu PRIVATE
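Because the compile definition is declared PUBLIC, any target that links against torch_xpu also sees USE_C10D_XCCL; this is the flag that gates the #ifdef blocks in the two files below. From Python, the simplest way to tell whether a given build was compiled with the flag is to probe for the binding it guards. A minimal sketch (the helper name build_has_xccl is hypothetical; it mirrors the _is_xccl_available helper added in torch/xpu/__init__.py later in this commit):

# Probe whether this PyTorch build was compiled with USE_C10D_XCCL by
# checking for the binding the flag guards (hypothetical helper name,
# mirroring _is_xccl_available from torch/xpu/__init__.py in this commit).
def build_has_xccl() -> bool:
    try:
        from torch._C._distributed_c10d import ProcessGroupXCCL  # noqa: F401
        return True
    except ImportError:
        return False

print("XCCL support compiled in:", build_has_xccl())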
torch/csrc/distributed/c10d/init.cpp: 21 additions & 0 deletions
@@ -30,6 +30,10 @@
 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 #endif
 
+#ifdef USE_C10D_XCCL
+#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
+#endif
+
 #ifdef USE_C10D_MPI
 #include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>
 #endif
@@ -2946,6 +2950,23 @@ Example::
           py::call_guard<py::gil_scoped_release>());
 #endif
 
+#ifdef USE_C10D_XCCL
+  auto processGroupXCCL =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
+          module, "ProcessGroupXCCL", backend)
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size) {
+                return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
+                    store, rank, size);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::call_guard<py::gil_scoped_release>());
+#endif
+
   py::enum_<::c10d::OpType>(module, "OpType")
       .value("BROADCAST", ::c10d::OpType::BROADCAST)
       .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE)
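The binding exposes a single constructor taking a c10d store, the rank, and the world size, and releases the GIL while the process group is created. Direct construction from Python would look roughly like the sketch below; normally the registration added in torch/xpu/__init__.py (next file) routes init_process_group here instead. The host and port are placeholders, and an XCCL-enabled build is assumed.

# Hypothetical direct use of the new ProcessGroupXCCL binding; assumes an
# XCCL-enabled build. RANK/WORLD_SIZE default to a single-process setup.
import os
import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroupXCCL

rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# TCPStore is one c10d::Store implementation the constructor accepts;
# the address and port here are placeholders.
store = dist.TCPStore("127.0.0.1", 29500, world_size, rank == 0)
pg = ProcessGroupXCCL(store, rank, world_size)  # GIL released during init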
torch/xpu/__init__.py: 27 additions & 0 deletions
@@ -459,6 +459,33 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
     default_generator = _get_generator(final_device)
     return default_generator.get_offset()
 
+def _is_xccl_available():
+    try:
+        from torch._C._distributed_c10d import ProcessGroupXCCL
+        return True
+    except ImportError:
+        return False
+
+_XCCL_AVAILABLE = _is_xccl_available()
+
+def _create_process_group_xccl(backend_opts, pg_opts):
+    if _XCCL_AVAILABLE:
+        from torch._C._distributed_c10d import ProcessGroupXCCL
+        return ProcessGroupXCCL(backend_opts.store, backend_opts.group_rank, backend_opts.group_size)
+    else:
+        return None
+
+def _init_xccl():
+    if _XCCL_AVAILABLE:
+        torch.distributed.Backend.register_backend(
+            "xccl",
+            _create_process_group_xccl,
+            devices=["xpu"],
+            extended_api=True
+        )
+
+_init_xccl()
+
 
 # import here to avoid circular import
 from .memory import (
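With extended_api=True, register_backend hands the creator function a backend-options object carrying the store, group rank, and group size, which is exactly what _create_process_group_xccl forwards to the ProcessGroupXCCL constructor. Since _init_xccl() runs when torch.xpu is imported, an XCCL-enabled build can then select the backend by name. A hedged end-to-end sketch, assuming XPU hardware and a torchrun-style environment; whether individual collectives are wired up is outside the scope of this commit, so the all_reduce call is illustrative:

# Usage sketch for the newly registered "xccl" backend; assumes an
# XCCL-enabled build, available XPU devices, and RANK/WORLD_SIZE/
# MASTER_ADDR/MASTER_PORT set by a launcher such as torchrun.
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="xccl")  # resolved via Backend.register_backend
torch.xpu.set_device(int(os.environ.get("LOCAL_RANK", "0")))

t = torch.ones(8, device="xpu")
dist.all_reduce(t)  # illustrative; dispatched to ProcessGroupXCCL
dist.destroy_process_group()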
