Skip to content

Commit

Permalink
register again
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Sep 5, 2024
1 parent 8d739ac commit fb9746b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 12 deletions.
1 change: 1 addition & 0 deletions torch/_C/_distributed_c10d.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ class ProcessGroup:
UNDEFINED = ...
GLOO = ...
NCCL = ...
XCCL = ...
UCC = ...
MPI = ...
CUSTOM = ...
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ static ProcessGroup::BackendType strToBackendType(std::string_view backend) {
return ProcessGroup::BackendType::GLOO;
} else if (backend == "nccl") {
return ProcessGroup::BackendType::NCCL;
} else if (backend == "xccl") {
return ProcessGroup::BackendType::XCCL;
} else if (backend == "ucc") {
return ProcessGroup::BackendType::UCC;
} else if (backend == "mpi") {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
getDefaultBackend()->setSequenceNumberForGroup();
} else {
Expand All @@ -511,6 +512,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
// TODO: HACK for backend name to get sequence number for that backend.
if (backendType == ProcessGroup::BackendType::GLOO ||
backendType == ProcessGroup::BackendType::NCCL ||
backendType == ProcessGroup::BackendType::XCCL ||
backendType == ProcessGroup::BackendType::UCC) {
return getDefaultBackend()->getSequenceNumberForGroup();
} else {
Expand Down
30 changes: 20 additions & 10 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,16 @@ std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
c10::impl::VirtualGuardImpl impl(device.type());
c10::Stream stream = impl.getStream(device);
sycl::queue& q = c10::xpu::XPUStream(stream).queue();
auto ctx = ccl::create_context(q.get_context());
devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
// const sycl::context& sycl_ctx = q.get_context();
// sycl::context sycl_ctx = q.get_context();
// ccl::generic_context_type<ccl::cl_backend_type::dpcpp_sycl_l0> ccl_ctx(sycl_ctx);
// auto ctx = ccl::create_context(ccl_ctx.get());

// auto ctx = ccl::create_context(q.get_context());
// devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
// XCCLComm = ccl::create_communicator(numRanks, devs_rank, ctx, kvs);
XCCLComm = std::make_shared<xcclComm_t>(ccl::create_communicator(numRanks, rank, kvs));


{
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -222,7 +229,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
const auto key = std::to_string(device.index());
auto comm = getXCCLComm(key, device);

auto xcclStream = xcclStreams_.at(key);
auto stream = xcclStreams_.at(key);
std::vector<at::Tensor> inputs{input};
std::vector<at::Tensor> outputs{output};

Expand All @@ -233,14 +240,17 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
work->outputs_ =
std::make_shared<std::vector<at::Tensor>>(std::move(outputs));
c10::xpu::XPUCachingAllocator::recordStream(
input.storage().data_ptr(), xcclStream);
input.storage().data_ptr(), stream);

auto ccl_stream = ccl::create_stream(xcclStream.queue());
fn(input, output, attr, comm, ccl_stream);
// auto ccl_stream = ccl::create_stream(stream.queue());
auto ccl_stream = ccl::create_stream();

fn(input, output, attr, *comm, ccl_stream);
// fn(input, output, attr, comm, ccl_stream);

work->xcclEndEvent_->record(xcclStream);
work->xcclEndEvent_->record(stream);

std::vector<c10::Stream> streams = {xcclStream.unwrap()};
std::vector<c10::Stream> streams = {stream.unwrap()};
c10::MultiStreamGuard streamGuard(streams);
std::vector<at::Device> devices{device};
work->future_ = c10::make_intrusive<at::ivalue::Future>(
Expand Down Expand Up @@ -283,7 +293,7 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
[&](at::Tensor& input,
at::Tensor& output,
ccl::allreduce_attr attr,
xcclComm_t comm,
xcclComm_t& comm,
ccl::stream& stream) {
ccl::event ret_evt;
ret_evt = ccl::allreduce(
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2237,6 +2237,7 @@ The hook must have the following signature:
.value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
.value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
.value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
.value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL)
.value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
.value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
.value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
Expand Down
11 changes: 9 additions & 2 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,22 +252,24 @@ class Backend(str):
NCCL = "nccl"
UCC = "ucc"
MPI = "mpi"
XCCL = "XCCL"
XCCL = "xccl"

_BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"])

_plugins: Dict[str, _BackendPlugin] = {}

backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI]
backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI]

default_device_backend_map: Dict[str, str] = {
"cpu": GLOO,
"cuda": NCCL,
"xpu": XCCL,
}

backend_capability: Dict[str, List[str]] = {
GLOO: ["cpu", "cuda"],
NCCL: ["cuda"],
XCCL: ["xpu"],
UCC: ["cpu", "cuda"],
MPI: ["cpu", "cuda"],
}
Expand All @@ -276,6 +278,7 @@ class Backend(str):
UNDEFINED: ProcessGroup.BackendType.UNDEFINED,
GLOO: ProcessGroup.BackendType.GLOO,
NCCL: ProcessGroup.BackendType.NCCL,
XCCL: ProcessGroup.BackendType.XCCL,
UCC: ProcessGroup.BackendType.UCC,
}

Expand Down Expand Up @@ -1364,6 +1367,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
backends.add(backend) # type: ignore[arg-type]
elif is_gloo_available() and isinstance(backend, ProcessGroupGloo):
backends.add(backend) # type: ignore[arg-type]
if torch.device("xpu") in devices and is_xpu_available():
backend = group._get_backend(torch.device("xpu"))
if isinstance(backend, ProcessGroupXCCL):
backends.add(backend) # type: ignore[arg-type]
if len(backends) == 0:
warnings.warn("Set timeout is now only supported for either nccl or gloo.")
for backend in backends:
Expand Down

0 comments on commit fb9746b

Please sign in to comment.