refine test cases
Chao1Han committed Sep 13, 2024
1 parent dc41d6a commit ecbd989
Showing 4 changed files with 173 additions and 9 deletions.
5 changes: 3 additions & 2 deletions test/distributed/test_c10d_common.py
@@ -66,8 +66,9 @@ def gpus_for_rank(world_size):
On a single node, all visible GPUs are evenly
divided into subsets; each process only uses a subset.
"""
visible_devices = list(range(torch.cuda.device_count()))
gpus_per_process = torch.cuda.device_count() // world_size
device_count = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
visible_devices = list(range(device_count))
gpus_per_process = device_count // world_size
gpus_for_rank = []
for rank in range(world_size):
gpus_for_rank.append(
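
Note: the edit above makes gpus_for_rank backend-agnostic by preferring XPU devices whenever an XPU runtime is present. A minimal sketch of the resulting rank-to-device assignment, assuming a PyTorch build that exposes torch.xpu; the list-comprehension form is a compact restatement of the patched helper, not the file's exact code.

import torch

def gpus_for_rank(world_size):
    """Evenly split the visible accelerators across world_size ranks."""
    # Prefer XPU devices when an XPU runtime is available, otherwise fall back to CUDA.
    device_count = (
        torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
    )
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]

# Example: with 4 visible devices and world_size=2, rank 0 gets [0, 1] and rank 1 gets [2, 3].
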
168 changes: 165 additions & 3 deletions test/distributed/test_c10d_xccl.py
@@ -7,7 +7,10 @@
import os
import random
import sys

import time
import tempfile
from datetime import timedelta
from functools import reduce
from unittest import mock, SkipTest

@@ -20,6 +23,7 @@
sys.exit(0)

import test_c10d_common
from test_c10d_common import DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook

import torch.distributed as dist
import torch.nn.functional as F
@@ -29,8 +33,12 @@
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_xccl,
init_multigpu_helper,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
skip_but_pass_in_sandcastle_if,
TEST_XPU,
retry_on_connect_failures,
run_tests,
TestCase,
@@ -62,10 +70,12 @@ def simple_reduce_tests(rank, world_size):

return tests

TEST_MULTIXPU = torch.xpu.device_count() > 1

class RendezvousEnvTest(TestCase):
@retry_on_connect_failures
@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
def test_common_errors(self):
vars = {
"WORLD_SIZE": "1",
Expand Down Expand Up @@ -164,13 +174,23 @@ def withouts(d, keys):
class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
@requires_xccl()
@retry_on_connect_failures
@skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
def test_default_store_timeout_nccl(self):
self._test_default_store_timeout("xccl")

class ProcessGroupXCCLTest(MultiProcessTestCase):
def _create_process_group_xccl(self):
def _create_process_group_xccl(self, timeout=timedelta(seconds=600), device_id=None):
store = c10d.FileStore(self.file_name, self.world_size)
return c10d.ProcessGroupXCCL(store, self.rank, self.world_size)
c10d.init_process_group(
"xccl",
world_size=self.world_size,
rank=self.rank,
store=store,
timeout=timeout,
device_id=device_id,
)
pg = c10d.distributed_c10d._get_default_group()
return pg

def setUp(self):
super().setUp()
@@ -182,7 +202,76 @@ def tearDown(self):
os.remove(self.file_name)
except OSError:
pass


@property
def world_size(self):
return 2

@property
def rank_to_GPU(self):
# return rank to GPU map
return init_multigpu_helper(self.world_size, "xccl")

@requires_xccl()
@skip_but_pass_in_sandcastle_if(
torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
)
def test_close_multi_pg_unordered(self):
pg = self._create_process_group_xccl()
device = self.rank_to_GPU[self.rank][0]
t = torch.rand(10, 10, device=device)
# First allreduce to initialize default PG's communicator.
pg.allreduce(t).wait()
new_pg1 = c10d.new_group([0, 1])
new_pg2 = c10d.new_group([0, 1])
if self.rank == 0 or self.rank == 1:
t1 = torch.rand(10, 10, device=device)
t2 = torch.rand(10, 10, device=device)
new_pg1.allreduce(t1).wait()
new_pg2.allreduce(t2).wait()
if self.rank == 0:
dist.destroy_process_group(new_pg2)
# force destruction of pg2 first
del new_pg2
dist.destroy_process_group(new_pg1)
del new_pg1
if self.rank == 1:
c10d.destroy_process_group(new_pg1)
# force destruction of pg1 first
del new_pg1
dist.destroy_process_group(new_pg2)
del new_pg2
dist.destroy_process_group()

@requires_xccl()
@skip_but_pass_in_sandcastle_if(
torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
)
def test_file_store_check(self):
# self.file_name is created using "delete=False"
# e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="xccl", rank=self.rank, world_size=self.world_size, store=store
)
pg = dist.distributed_c10d._get_default_group()
self.assertEqual(pg.rank(), self.rank)
self.assertEqual(pg.size(), self.world_size)
# give enough time for check() to be executed multiple times
time.sleep(2)
dist.destroy_process_group()

@requires_xccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs")
def test_set_process_group_desc(self):
device = torch.device(f"xpu:{self.rank}")
pg_default = self._create_process_group_xccl(device_id=device)
self.assertEqual(pg_default.group_desc, "default_pg")
pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
self.assertEqual(pg_1.group_desc, "test_purpose")
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")

def _test_allreduce_basics(self, fn):
pg = self._create_process_group_xccl()
device = torch.device("xpu:" + str(self.rank))
@@ -210,6 +299,79 @@ def _test_allreduce_basics(self, fn):
def test_allreduce_basics(self):
self._test_allreduce_basics(lambda t: t.clone())

class DistributedDataParallelTest(
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
def setUp(self):
super().setUp()
self._spawn_processes()

def _get_process_group(self):
store = self._get_store()
c10d.init_process_group(
"xccl", store=store, rank=self.rank, world_size=self.world_size
)
return c10d.distributed_c10d._get_default_group()

def _test_xccl_backend(
self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False
):
process_group = self._get_process_group()
self._test_ddp_with_process_group(
process_group, devices, device_ids, multi_device, gradient_as_bucket_view
)

@requires_xccl()
@skip_if_lt_x_gpu(2)
def test_xccl_backend_multi_device_ids_not_allowed(self):
int_devices = list(range(torch.xpu.device_count()))
devices = [torch.device("xpu:" + str(i)) for i in int_devices]
with self.assertRaisesRegex(
ValueError, "device_ids can only be None or contain a single element."
):
self._test_xccl_backend(devices, int_devices)

@requires_xccl()
@skip_if_lt_x_gpu(4)
def test_ddp_multi_device_module_config(self):
gpus = gpus_for_rank(self.world_size)[self.rank]

self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process")

process_group = self._get_process_group()

gpus = gpus[:2]
model = DoubleGpuNet(gpus)

with self.assertRaisesRegex(
ValueError,
"DistributedDataParallel device_ids and output_device arguments only work with "
"single-device/multiple-device GPU modules or CPU modules",
):
ddp_model = DistributedDataParallel(
model, output_device=gpus[1], process_group=process_group
)

with self.assertRaisesRegex(
ValueError, "device_ids can only be None or contain a single element."
):
ddp_model = DistributedDataParallel(
model, device_ids=gpus, process_group=process_group
)

with self.assertRaisesRegex(
ValueError, "input module must be on the same type of devices"
):
model.fc1 = model.fc1.cpu()
ddp_model = DistributedDataParallel(model, process_group=process_group)

model = model.cpu()
with self.assertRaisesRegex(
ValueError, "device_ids can only be None or contain a single element."
):
ddp_model = DistributedDataParallel(
model, device_ids=gpus, process_group=process_group
)


if __name__ == "__main__":
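
Note: the refactored _create_process_group_xccl now initializes the default group through c10d.init_process_group("xccl", ...) and returns it, instead of constructing ProcessGroupXCCL directly. A minimal sketch of the pattern these tests exercise, assuming an XCCL-enabled PyTorch build, one XPU per rank, and a two-process launch (for example via torchrun) that sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT.

import os
from datetime import timedelta

import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"xpu:{rank}")

# Same shape as _create_process_group_xccl: bind the default group to one XPU per rank.
dist.init_process_group(
    "xccl",
    rank=rank,
    world_size=world_size,
    timeout=timedelta(seconds=600),
    device_id=device,
)

t = torch.ones(10, 10, device=device)
dist.all_reduce(t)  # sums across ranks, as the allreduce-based tests above do
assert torch.equal(t, torch.full((10, 10), float(world_size), device=device))

dist.destroy_process_group()
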
4 changes: 2 additions & 2 deletions torch/distributed/distributed_c10d.py
@@ -1672,10 +1672,10 @@ def _new_process_group_helper(
"created, please use a different group name"
)

if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
if device_id is not None and (device_id.index is None or (device_id.type != "cuda" and device_id.type != "xpu")):
raise ValueError(
"init_process_group device_id parameter must be a cuda device with an "
"id, e.g. cuda:0, not just cuda or cpu"
"id, e.g. cuda:0, xpu, not just cuda or xpu or cpu"
)

# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
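
Note: the relaxed check now accepts any indexed cuda or xpu device for device_id. A small sketch of that validation in isolation, assuming only standard torch.device parsing; the _validate_device_id name is illustrative and not part of the patch.

import torch

def _validate_device_id(device_id):
    # Mirrors the updated check: an indexed cuda or xpu device passes, anything else is rejected.
    if device_id is not None and (
        device_id.index is None or (device_id.type != "cuda" and device_id.type != "xpu")
    ):
        raise ValueError(
            "init_process_group device_id parameter must be a cuda or xpu device with an "
            "id, e.g. cuda:0 or xpu:0, not just cuda, xpu, or cpu"
        )

_validate_device_id(torch.device("xpu:0"))   # accepted after this change
_validate_device_id(torch.device("cuda:0"))  # still accepted
_validate_device_id(None)                    # no device pinning requested; accepted
# _validate_device_id(torch.device("xpu"))   # would raise: no device index
# _validate_device_id(torch.device("cpu"))   # would raise: unsupported device type
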
5 changes: 3 additions & 2 deletions torch/testing/_internal/common_distributed.py
@@ -180,7 +180,8 @@ def skip_if_lt_x_gpu(x):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if torch.cuda.is_available() and torch.cuda.device_count() >= x:
if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \
(torch.xpu.is_available() and torch.xpu.device_count() >= x):
return func(*args, **kwargs)
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)

@@ -469,7 +470,7 @@ def init_multigpu_helper(world_size: int, backend: str):
On a single node, all visible GPUs are evenly
divided into subsets; each process only uses a subset.
"""
nGPUs = torch.cuda.device_count()
nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
visible_devices = range(nGPUs)

# If rank is less than or equal to the number of available GPUs
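
Note: with this change, a test guarded by skip_if_lt_x_gpu(2) also runs on a machine that has two or more XPUs and no CUDA devices. A minimal sketch of the updated skip logic, with sys.exit(0) standing in for the harness's TEST_SKIPS[f"multi-gpu-{x}"].exit_code.

import sys
from functools import wraps

import torch

def skip_if_lt_x_gpu(x):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Run the test if either backend exposes at least x devices; otherwise skip.
            if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or (
                torch.xpu.is_available() and torch.xpu.device_count() >= x
            ):
                return func(*args, **kwargs)
            sys.exit(0)  # stand-in for the real skip exit code
        return wrapper
    return decorator

@skip_if_lt_x_gpu(2)
def test_needs_two_devices():
    print("running: at least two CUDA or XPU devices are visible")
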
