diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 903df26bba9f6..d3cb65f7befb1 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -31,6 +31,7 @@ from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
+    get_device_count,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -60,17 +61,13 @@ torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def gpus_for_rank(world_size):
+def gpus_for_rank(world_size, backend):
     """Multigpu tests are designed to simulate the multi nodes with
     multi GPUs on each node. Nccl backend requires equal #GPUs in each process.
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    device_count = (
-        torch.xpu.device_count()
-        if torch.xpu.is_available()
-        else torch.cuda.device_count()
-    )
+    device_count = get_device_count(backend)
     visible_devices = list(range(device_count))
     gpus_per_process = device_count // world_size
     gpus_for_rank = []
@@ -833,7 +830,7 @@ def update_parameters(model):
     def _gpu_model_with_ddp_comm_hook(
         self, process_group, hook=None, gradient_as_bucket_view=False, state=None
     ):
-        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
         gpu_model = DistributedDataParallel(
             ModuleForDdpCommHook().to(device_id),
             device_ids=[device_id],
@@ -850,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook(
     def _gpu_model_with_builtin_ddp_comm_hook(
         self, process_group, hook=None, gradient_as_bucket_view=False
     ):
-        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
         gpu_model = DistributedDataParallel(
             ModuleForDdpCommHook().to(device_id),
             device_ids=[device_id],
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 9ec38c9ca671c..3e1664690b713 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -93,8 +93,9 @@ class DistTestCases:
 
     # Sets showing that something is implemented
     backend_feature = {}
-    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"}
     backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["xpu"] = {"xccl"}
     backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
     backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
     backend_feature["plugin"] = set()
@@ -462,6 +463,15 @@ def compute_sum(fn, world_size: int):
     ]
 ]
 
+# Returns the number of GPUs, currently only for CUDA and XPU.
+def get_device_count(backend: str):
+    assert c10d.is_backend_available(backend)
+    if backend in backend_feature.get("cuda", set()):
+        return torch.cuda.device_count()
+    elif backend in backend_feature.get("xpu", set()):
+        return torch.xpu.device_count()
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
 
 # HELPER FOR MULTIGPU TESTS
 def init_multigpu_helper(world_size: int, backend: str):
@@ -470,7 +480,7 @@ def init_multigpu_helper(world_size: int, backend: str):
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
+    nGPUs = get_device_count(backend)
     visible_devices = range(nGPUs)
 
     # If rank is less than or equal to number of available GPU's
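
Illustrative note (not part of the patch): in the patch, device_count comes from the new get_device_count(backend), which dispatches to torch.cuda.device_count() or torch.xpu.device_count() based on the backend_feature sets. The snippet below is a minimal standalone sketch of the per-rank split that gpus_for_rank then builds; the split_devices helper and the contiguous-slice partition are assumptions made for the example, since the loop body of gpus_for_rank is not shown in the diff.

    # Minimal sketch, assuming gpus_for_rank slices the visible devices into
    # contiguous per-rank chunks; only the visible_devices / gpus_per_process
    # setup is taken verbatim from the diff.
    def split_devices(device_count, world_size):
        visible_devices = list(range(device_count))
        gpus_per_process = device_count // world_size
        return [
            visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
            for rank in range(world_size)
        ]

    # e.g. an 8-GPU node (CUDA via "nccl" or XPU via "xccl") split across 4 ranks
    assert split_devices(8, 4) == [[0, 1], [2, 3], [4, 5], [6, 7]]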