From ecbd9894c4dcca31d8b10746231c3a0d2d155d85 Mon Sep 17 00:00:00 2001 From: hanchao Date: Fri, 13 Sep 2024 06:26:27 +0000 Subject: [PATCH] refine test cases --- test/distributed/test_c10d_common.py | 5 +- test/distributed/test_c10d_xccl.py | 168 +++++++++++++++++- torch/distributed/distributed_c10d.py | 4 +- torch/testing/_internal/common_distributed.py | 5 +- 4 files changed, 173 insertions(+), 9 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 6a0621f3f4991..0c1426d0e29c2 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -66,8 +66,9 @@ def gpus_for_rank(world_size): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + device_count = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count() + visible_devices = list(range(device_count)) + gpus_per_process = device_count // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( diff --git a/test/distributed/test_c10d_xccl.py b/test/distributed/test_c10d_xccl.py index 33a2f196c3b5d..a998af7b16ef9 100644 --- a/test/distributed/test_c10d_xccl.py +++ b/test/distributed/test_c10d_xccl.py @@ -7,7 +7,10 @@ import os import random import sys + +import time import tempfile +from datetime import timedelta from functools import reduce from unittest import mock, SkipTest @@ -20,6 +23,7 @@ sys.exit(0) import test_c10d_common +from test_c10d_common import DoubleGpuNet, gpus_for_rank, ModuleForDdpCommHook import torch.distributed as dist import torch.nn.functional as F @@ -29,8 +33,12 @@ from torch.testing._internal.common_distributed import ( MultiProcessTestCase, requires_xccl, + init_multigpu_helper, + skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import ( + skip_but_pass_in_sandcastle_if, + TEST_XPU, retry_on_connect_failures, run_tests, TestCase, @@ -62,10 +70,12 @@ def simple_reduce_tests(rank, world_size): return tests +TEST_MULTIXPU = torch.xpu.device_count() > 1 class RendezvousEnvTest(TestCase): @retry_on_connect_failures @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") def test_common_errors(self): vars = { "WORLD_SIZE": "1", @@ -164,13 +174,23 @@ def withouts(d, keys): class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase): @requires_xccl() @retry_on_connect_failures + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") def test_default_store_timeout_nccl(self): self._test_default_store_timeout("xccl") class ProcessGroupXCCLTest(MultiProcessTestCase): - def _create_process_group_xccl(self): + def _create_process_group_xccl(self, timeout=timedelta(seconds=600), device_id=None): store = c10d.FileStore(self.file_name, self.world_size) - return c10d.ProcessGroupXCCL(store, self.rank, self.world_size) + c10d.init_process_group( + "xccl", + world_size=self.world_size, + rank=self.rank, + store=store, + timeout=timeout, + device_id=device_id, + ) + pg = c10d.distributed_c10d._get_default_group() + return pg def setUp(self): super().setUp() @@ -182,7 +202,76 @@ def tearDown(self): os.remove(self.file_name) except OSError: pass - + + @property + def world_size(self): + return 2 + + @property + def rank_to_GPU(self): + # return rank to GPU map + return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. + pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + def test_set_process_group_desc(self): + device = torch.device(f"xpu:{self.rank}") + pg_default = self._create_process_group_xccl(device_id=device) + self.assertEqual(pg_default.group_desc, "default_pg") + pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + self.assertEqual(pg_1.group_desc, "test_purpose") + pg_2 = c10d.new_group([0, 1]) + self.assertEqual(pg_2.group_desc, "undefined") + def _test_allreduce_basics(self, fn): pg = self._create_process_group_xccl() device = torch.device("xpu:" + str(self.rank)) @@ -210,6 +299,79 @@ def _test_allreduce_basics(self, fn): def test_allreduce_basics(self): self._test_allreduce_basics(lambda t: t.clone()) +class DistributedDataParallelTest( + test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase +): + def setUp(self): + super().setUp() + self._spawn_processes() + + def _get_process_group(self): + store = self._get_store() + c10d.init_process_group( + "xccl", store=store, rank=self.rank, world_size=self.world_size + ) + return c10d.distributed_c10d._get_default_group() + + def _test_xccl_backend( + self, devices, device_ids, multi_device=False, gradient_as_bucket_view=False + ): + process_group = self._get_process_group() + self._test_ddp_with_process_group( + process_group, devices, device_ids, multi_device, gradient_as_bucket_view + ) + + @requires_xccl() + @skip_if_lt_x_gpu(2) + def test_xccl_backend_multi_device_ids_not_allowed(self): + int_devices = list(range(torch.xpu.device_count())) + devices = [torch.device("xpu:" + str(i)) for i in int_devices] + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." + ): + self._test_xccl_backend(devices, int_devices) + + @requires_xccl() + @skip_if_lt_x_gpu(4) + def test_ddp_multi_device_module_config(self): + gpus = gpus_for_rank(self.world_size)[self.rank] + + self.assertTrue(len(gpus) >= 2, "expecting at least 2 gpus per process") + + process_group = self._get_process_group() + + gpus = gpus[:2] + model = DoubleGpuNet(gpus) + + with self.assertRaisesRegex( + ValueError, + "DistributedDataParallel device_ids and output_device arguments only work with " + "single-device/multiple-device GPU modules or CPU modules", + ): + ddp_model = DistributedDataParallel( + model, output_device=gpus[1], process_group=process_group + ) + + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." + ): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group + ) + + with self.assertRaisesRegex( + ValueError, "input module must be on the same type of devices" + ): + model.fc1 = model.fc1.cpu() + ddp_model = DistributedDataParallel(model, process_group=process_group) + + model = model.cpu() + with self.assertRaisesRegex( + ValueError, "device_ids can only be None or contain a single element." + ): + ddp_model = DistributedDataParallel( + model, device_ids=gpus, process_group=process_group + ) if __name__ == "__main__": diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 3f68609905bb5..d0781765c090f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1672,10 +1672,10 @@ def _new_process_group_helper( "created, please use a different group name" ) - if device_id is not None and (device_id.index is None or device_id.type != "cuda"): + if device_id is not None and (device_id.index is None or (device_id.type != "cuda" and device_id.type != "xpu")): raise ValueError( "init_process_group device_id parameter must be a cuda device with an " - "id, e.g. cuda:0, not just cuda or cpu" + "id, e.g. cuda:0, xpu, not just cuda or xpu or cpu" ) # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index ff83bc8ab6666..554114b7bbcb1 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -180,7 +180,8 @@ def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if torch.cuda.is_available() and torch.cuda.device_count() >= x: + if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \ + (torch.xpu.is_available() and torch.xpu.device_count() >= x): return func(*args, **kwargs) sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code) @@ -469,7 +470,7 @@ def init_multigpu_helper(world_size: int, backend: str): On a single node, all visible GPUs are evenly divided to subsets, each process only uses a subset. """ - nGPUs = torch.cuda.device_count() + nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count() visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's