diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi
index 0c97185519d28f..53011cde6b178a 100644
--- a/torch/_C/_distributed_c10d.pyi
+++ b/torch/_C/_distributed_c10d.pyi
@@ -309,6 +309,7 @@ class ProcessGroup:
         UNDEFINED = ...
         GLOO = ...
         NCCL = ...
+        XCCL = ...
         UCC = ...
         MPI = ...
         CUSTOM = ...
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp
index 75635bc68aed4f..70356b3bf382ce 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp
@@ -21,6 +21,8 @@ static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
     return ProcessGroup::BackendType::GLOO;
   } else if (backend == "nccl") {
     return ProcessGroup::BackendType::NCCL;
+  } else if (backend == "xccl") {
+    return ProcessGroup::BackendType::XCCL;
   } else if (backend == "ucc") {
     return ProcessGroup::BackendType::UCC;
   } else if (backend == "mpi") {
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 85142caf0ac7c7..73fc2bda701327 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -490,6 +490,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // TODO: HACK for backend name to get sequence number for that backend.
     if (backendType == ProcessGroup::BackendType::GLOO ||
         backendType == ProcessGroup::BackendType::NCCL ||
+        backendType == ProcessGroup::BackendType::XCCL ||
         backendType == ProcessGroup::BackendType::UCC) {
       getDefaultBackend()->setSequenceNumberForGroup();
     } else {
@@ -511,6 +512,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     // TODO: HACK for backend name to get sequence number for that backend.
    if (backendType == ProcessGroup::BackendType::GLOO ||
        backendType == ProcessGroup::BackendType::NCCL ||
+       backendType == ProcessGroup::BackendType::XCCL ||
        backendType == ProcessGroup::BackendType::UCC) {
      return getDefaultBackend()->getSequenceNumberForGroup();
    } else {
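Taken together, the three files above wire the new backend type through the stack: the .pyi stub exposes ProcessGroup.BackendType.XCCL, strToBackendType() maps the string "xccl" onto it, and the sequence-number hack in ProcessGroup.hpp treats XCCL the same as GLOO, NCCL, and UCC. As a reading aid only, here is a minimal Python sketch of that string-to-enum resolution; the enum and helper names are illustrative stand-ins, the real mapping is the C++ above:

from enum import Enum, auto

class BackendType(Enum):
    UNDEFINED = auto()
    GLOO = auto()
    NCCL = auto()
    XCCL = auto()
    UCC = auto()
    MPI = auto()
    CUSTOM = auto()

def str_to_backend_type(backend: str) -> BackendType:
    known = {
        "undefined": BackendType.UNDEFINED,
        "gloo": BackendType.GLOO,
        "nccl": BackendType.NCCL,
        "xccl": BackendType.XCCL,
        "ucc": BackendType.UCC,
        "mpi": BackendType.MPI,
    }
    # Anything unrecognized is treated as a custom/plugin backend,
    # mirroring the fall-through branch of strToBackendType above.
    return known.get(backend, BackendType.CUSTOM)

assert str_to_backend_type("xccl") is BackendType.XCCL

The CUSTOM fall-through matters: a name not in the table is assumed to be a third-party plugin backend rather than an error, which is why adding "xccl" to the chain is required for it to get a first-class enum value.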
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
index ffd566f10f854a..e21be88ef83d16 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -178,9 +178,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
   c10::impl::VirtualGuardImpl impl(device.type());
   c10::Stream stream = impl.getStream(device);
   sycl::queue& q = c10::xpu::XPUStream(stream).queue();
-  auto ctx = ccl::create_context(q.get_context());
-  devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
-  XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
+  // const sycl::context& sycl_ctx = q.get_context();
+  // sycl::context sycl_ctx = q.get_context();
+  // ccl::generic_context_type ccl_ctx(sycl_ctx);
+  // auto ctx = ccl::create_context(ccl_ctx.get());
+
+  // auto ctx = ccl::create_context(q.get_context());
+  // devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
+  // XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
+  XCCLComm = std::make_shared<xcclComm_t>(ccl::create_communicator(numRanks, rank, kvs));
+
   {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -222,7 +229,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   const auto key = std::to_string(device.index());
   auto comm = getXCCLComm(key, device);

-  auto xcclStream = xcclStreams_.at(key);
+  auto stream = xcclStreams_.at(key);
   std::vector<at::Tensor> inputs{input};
   std::vector<at::Tensor> outputs{output};
@@ -233,14 +240,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(std::move(outputs));

   c10::xpu::XPUCachingAllocator::recordStream(
-      input.storage().data_ptr(), xcclStream);
+      input.storage().data_ptr(), stream);

-  auto ccl_stream = ccl::create_stream(xcclStream.queue());
-  fn(input, output, attr, comm, ccl_stream);
+  // auto ccl_stream = ccl::create_stream(stream.queue());
+  auto ccl_stream = ccl::create_stream();
+
+  fn(input, output, attr, *comm, ccl_stream);
+  // fn(input, output, attr, comm, ccl_stream);

-  work->xcclEndEvent_->record(xcclStream);
+  work->xcclEndEvent_->record(stream);

-  std::vector<c10::Stream> streams = {xcclStream.unwrap()};
+  std::vector<c10::Stream> streams = {stream.unwrap()};
   c10::MultiStreamGuard streamGuard(streams);
   std::vector<at::Device> devices{device};
   work->future_ = c10::make_intrusive<at::ivalue::Future>(
@@ -283,7 +293,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
       [&](at::Tensor& input,
           at::Tensor& output,
           ccl::allreduce_attr attr,
-          xcclComm_t comm,
+          xcclComm_t& comm,
           ccl::stream& stream) {
         ccl::event ret_evt;
         ret_evt = ccl::allreduce(
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index e12e96f9fe882f..5d200bb6eeb9cf 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -2237,6 +2237,7 @@ The hook must have the following signature:
       .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
       .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
       .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
+      .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL)
       .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
       .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
       .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
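With the communicator cached per device key in getXCCLComm and allreduce dispatched through ccl::allreduce, the user-visible effect is that "xccl" behaves like any other backend string. A hypothetical end-to-end sketch, assuming an XPU-enabled PyTorch build with this patch applied and the usual RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables set by the launcher:

import os
import torch
import torch.distributed as dist

def main() -> None:
    # Rank and world size come from the launcher environment.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend="xccl", rank=rank, world_size=world_size)

    device = torch.device("xpu", rank % torch.xpu.device_count())
    t = torch.ones(4, device=device)
    dist.all_reduce(t)  # routed to ProcessGroupXCCL::allreduce above
    assert torch.allclose(t.cpu(), torch.full((4,), float(world_size)))

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Launching would follow the usual pattern, e.g. torchrun --nproc-per-node=2 script.py, with one process per XPU device.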
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 26cb1cda1db8cb..3f68609905bb5a 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -252,22 +252,24 @@ class Backend(str):
     NCCL = "nccl"
     UCC = "ucc"
     MPI = "mpi"
-    XCCL = "XCCL"
+    XCCL = "xccl"

     _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])
     _plugins: Dict[str, _BackendPlugin] = {}

-    backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
+    backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]

     default_device_backend_map: Dict[str, str] = {
         "cpu": GLOO,
         "cuda": NCCL,
+        "xpu": XCCL,
     }

     backend_capability: Dict[str, List[str]] = {
         GLOO: ["cpu", "cuda"],
         NCCL: ["cuda"],
+        XCCL: ["xpu"],
         UCC: ["cpu", "cuda"],
         MPI: ["cpu", "cuda"],
     }
@@ -276,6 +278,7 @@ class Backend(str):
         UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
         GLOO: ProcessGroup.BackendType.GLOO,
         NCCL: ProcessGroup.BackendType.NCCL,
+        XCCL: ProcessGroup.BackendType.XCCL,
         UCC: ProcessGroup.BackendType.UCC,
     }
@@ -1364,6 +1367,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> None:
                 backends.add(backend)  # type: ignore[arg-type]
         elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
             backends.add(backend)  # type: ignore[arg-type]
+    if torch.device("xpu") in devices and is_xpu_available():
+        backend = group._get_backend(torch.device("xpu"))
+        if isinstance(backend, ProcessGroupXCCL):
+            backends.add(backend)  # type: ignore[arg-type]
     if len(backends) == 0:
         warnings.warn("Set timeout is now only supported for either nccl or gloo.")
     for backend in backends:
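Because default_device_backend_map now contains "xpu": XCCL (and the XCCL value is normalized to lowercase "xccl"), init_process_group can infer the backend from the device type when none is given. A small illustrative sketch of that defaulting behavior; default_backend_for is a hypothetical helper, not the actual API:

# Mirrors the mapping registered on Backend above.
default_device_backend_map = {"cpu": "gloo", "cuda": "nccl", "xpu": "xccl"}

def default_backend_for(device_type: str) -> str:
    # Resolve the backend string for a device type, as init_process_group
    # does when called with backend=None.
    try:
        return default_device_backend_map[device_type]
    except KeyError:
        raise ValueError(f"No default distributed backend for device '{device_type}'")

assert default_backend_for("xpu") == "xccl"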