Merge branch 'xccl' into xccl-group
Chao1Han committed Sep 13, 2024
2 parents 009e334 + 2d1ae87 commit 04226de
Showing 6 changed files with 117 additions and 16 deletions.
5 changes: 3 additions & 2 deletions test/distributed/test_c10d_common.py
@@ -66,8 +66,9 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
device_count = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
visible_devices = list(range(device_count))
gpus_per_process = device_count // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
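For context on the helper this hunk touches, here is a minimal, self-contained sketch of the even split it performs; the device count of 4 and world size of 2 are illustrative values, not taken from the diff.

# Minimal sketch of the even device split performed by gpus_for_rank.
# device_count=4 and world_size=2 below are illustrative values only.
def split_devices(device_count, world_size):
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]

print(split_devices(4, 2))  # [[0, 1], [2, 3]]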
95 changes: 92 additions & 3 deletions test/distributed/test_c10d_xccl.py
@@ -7,7 +7,10 @@
import os
import random
import sys

import time
import tempfile
from datetime import timedelta
from functools import reduce
from unittest import mock, SkipTest

@@ -20,6 +23,7 @@
sys.exit(0)

import test_c10d_common
from test_c10d_common import DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook

import torch.distributed as dist
import torch.nn.functional as F
@@ -29,8 +33,12 @@
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_xccl,
init_multigpu_helper,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
skip_but_pass_in_sandcastle_if,
TEST_XPU,
retry_on_connect_failures,
run_tests,
TestCase,
@@ -62,10 +70,12 @@ def simple_reduce_tests(rank, world_size):

return tests

TEST_MULTIXPU = torch.xpu.device_count() > 1

class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
def test_common_errors(self):
vars = {
"WORLD_SIZE": "1",
Expand Down Expand Up @@ -164,13 +174,23 @@ def withouts(d, keys):
class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
@requires_xccl()
@retry_on_connect_failures
@skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
def test_default_store_timeout_nccl(self):
self._test_default_store_timeout("xccl")

class ProcessGroupXCCLTest(MultiProcessTestCase):
def _create_process_group_xccl(self):
def _create_process_group_xccl(self, timeout=timedelta(seconds=600), device_id=None):
store = c10d.FileStore(self.file_name, self.world_size)
return c10d.ProcessGroupXCCL(store, self.rank, self.world_size)
c10d.init_process_group(
"xccl",
world_size=self.world_size,
rank=self.rank,
store=store,
timeout=timeout,
device_id=device_id,
)
pg = c10d.distributed_c10d._get_default_group()
return pg

def setUp(self):
super().setUp()
@@ -182,7 +202,76 @@ def tearDown(self):
os.remove(self.file_name)
except OSError:
pass


@property
def world_size(self):
return 2

@property
def rank_to_GPU(self):
# return rank to GPU map
return init_multigpu_helper(self.world_size, "xccl")

@requires_xccl()
@skip_but_pass_in_sandcastle_if(
torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
)
def test_close_multi_pg_unordered(self):
pg = self._create_process_group_xccl()
device = self.rank_to_GPU[self.rank][0]
t = torch.rand(10, 10, device=device)
# First allreduce to initialize default PG's communicator.
pg.allreduce(t).wait()
new_pg1 = c10d.new_group([0, 1])
new_pg2 = c10d.new_group([0, 1])
if self.rank == 0 or self.rank == 1:
t1 = torch.rand(10, 10, device=device)
t2 = torch.rand(10, 10, device=device)
new_pg1.allreduce(t1).wait()
new_pg2.allreduce(t2).wait()
if self.rank == 0:
dist.destroy_process_group(new_pg2)
# force destruction of pg2 first
del new_pg2
dist.destroy_process_group(new_pg1)
del new_pg1
if self.rank == 1:
c10d.destroy_process_group(new_pg1)
# force destruction of pg1 first
del new_pg1
dist.destroy_process_group(new_pg2)
del new_pg2
dist.destroy_process_group()

@requires_xccl()
@skip_but_pass_in_sandcastle_if(
torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
)
def test_file_store_check(self):
# self.file_name is created using "delete=False"
# e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="xccl", rank=self.rank, world_size=self.world_size, store=store
)
pg = dist.distributed_c10d._get_default_group()
self.assertEqual(pg.rank(), self.rank)
self.assertEqual(pg.size(), self.world_size)
# give enough time for check() to be executed multiple times
time.sleep(2)
dist.destroy_process_group()

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs")
def test_set_process_group_desc(self):
device = torch.device(f"xpu:{self.rank}")
pg_default = self._create_process_group_xccl(device_id=device)
self.assertEqual(pg_default.group_desc, "default_pg")
pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
self.assertEqual(pg_1.group_desc, "test_purpose")
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")

def _test_allreduce_basics(self, fn):
pg = self._create_process_group_xccl()
device = torch.device("xpu:" + str(self.rank))
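For reference, a minimal per-rank sketch of the initialization pattern the updated _create_process_group_xccl exercises. It assumes an XPU build of PyTorch with the xccl backend available; the store path, rank, and world size are placeholders, and every rank would run the same code with its own rank index.

import torch
import torch.distributed as dist

# Per-rank sketch only: the store path, rank, and world size are placeholders,
# and an XPU build with the xccl backend is assumed.
store = dist.FileStore("/tmp/xccl_rendezvous", 2)
dist.init_process_group(
    "xccl",
    rank=0,                     # each rank passes its own index
    world_size=2,
    store=store,
    device_id=torch.device("xpu:0"),
)
pg = dist.distributed_c10d._get_default_group()
t = torch.rand(10, 10, device="xpu:0")
pg.allreduce(t).wait()          # first collective initializes the communicator
dist.destroy_process_group()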
11 changes: 8 additions & 3 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -47,7 +47,6 @@ std::map<c10d::ReduceOp, ccl::reduction> xcclOps = {
std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kByte, ccl::datatype::uint8},
{at::kChar, ccl::datatype::int8},
{at::kShort, ccl::datatype::int16},
{at::kInt, ccl::datatype::int32},
{at::kLong, ccl::datatype::int64},
{at::kHalf, ccl::datatype::float16},
@@ -220,13 +219,19 @@ bool ProcessGroupXCCL::WorkXCCL::checkTimeout(
return true;
}

void ProcessGroupXCCL::WorkXCCL::finishWorkXcclError(
const std::exception_ptr& eptr) {
future_->setError(eptr);
finish(eptr);
}

bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
for (auto& ret : rets) {
bool flag;
try {
TORCH_CHECK(flag = ret.test());
} catch (...) {
finishAWorkXCCLError(std::current_exception());
finishWorkXcclError(std::current_exception());
return true;
}
if (!flag) {
@@ -287,7 +292,7 @@ ProcessGroupXCCL::ProcessGroupXCCL(
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
init();

{
if (!with_mpirun()) {
int local_rank = getXCCLEnvVar("LOCAL_RANK");
int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE");
if (local_rank == -1 || local_world_size == -1) {
13 changes: 9 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -55,6 +55,13 @@ void setXCCLEnvVar(std::string envVarName, int val) {
void setXCCLEnvVar(std::string envVarName, std::string val) {
setenv(envVarName.c_str(), val.c_str(), 1);
}

bool with_mpirun() {
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
? true
: false;
}
} // namespace

static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
@@ -88,6 +95,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
rets.push_back(std::move(result));
}

void finishWorkXcclError(const std::exception_ptr& eptr);

bool isCompleted() override;

bool isSuccess() const override {
@@ -124,10 +133,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
std::vector<ccl::event> rets;

private:
void finishAWorkXCCLError(std::exception_ptr eptr) {
future_->setError(eptr);
finish(eptr);
}
void synchronizeInternal(std::chrono::milliseconds timeout);
std::shared_ptr<std::vector<at::Tensor>> outputs_;
c10::intrusive_ptr<at::ivalue::Future> future_;
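The launcher check added above only consults environment variables. As a rough sketch (written in Python but mirroring the variable names in the C++ helper), the same test could be expressed as:

import os

# Rough Python equivalent of with_mpirun(): true when any of the MPI/PMI
# launcher variables named in the C++ helper is set in the environment.
def with_mpirun():
    return any(
        os.getenv(name) is not None
        for name in ("MPI_LOCALRANKID", "MPI_LOCALNRANKS",
                     "PMI_RANK", "PMI_SIZE", "PMIX_RANK")
    )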
4 changes: 2 additions & 2 deletions torch/distributed/distributed_c10d.py
@@ -1672,10 +1672,10 @@ def _new_process_group_helper(
"created, please use a different group name"
)

if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
if device_id is not None and (device_id.index is None or (device_id.type != "cuda" and device_id.type != "xpu")):
raise ValueError(
"init_process_group device_id parameter must be a cuda device with an "
"id, e.g. cuda:0, not just cuda or cpu"
"id, e.g. cuda:0, xpu, not just cuda or xpu or cpu"
)

# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
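A small sketch of the relaxed rule, written as a standalone predicate rather than the library function itself; device_id_ok is a made-up name used only for illustration.

import torch

# Mirrors the condition in _new_process_group_helper after this change:
# device_id may be omitted, or it must be an indexed cuda/xpu device.
def device_id_ok(device_id):
    if device_id is None:
        return True
    return device_id.index is not None and device_id.type in ("cuda", "xpu")

print(device_id_ok(torch.device("xpu:0")))   # True, newly accepted
print(device_id_ok(torch.device("cuda:0")))  # True
print(device_id_ok(torch.device("xpu")))     # False, no index
print(device_id_ok(torch.device("cpu")))     # False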
5 changes: 3 additions & 2 deletions torch/testing/_internal/common_distributed.py
@@ -180,7 +180,8 @@ def skip_if_lt_x_gpu(x):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \
(torch.xpu.is_available() and torch.xpu.device_count() >= x):
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

@@ -469,7 +470,7 @@ def init_multigpu_helper(world_size: int, backend: str):
On a single node, all visible GPUs are evenly
divided to subsets, each process only uses a subset.
"""
nGPUs = torch.cuda.device_count()
nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
visible_devices = range(nGPUs)

# If rank is less than or equal to number of available GPU's
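Both helpers now share the same backend preference, sketched below; the function name is made up for illustration, and a machine exposing either CUDA or XPU devices is assumed.

import torch

# Prefer the XPU runtime when it is available, otherwise fall back to CUDA.
# With this change, skip_if_lt_x_gpu(2) also passes on an XPU-only machine
# that exposes two or more devices.
def accelerator_device_count():
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    return torch.cuda.device_count()

print(accelerator_device_count())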
