
Commit

add get_device_count
Chao1Han committed Oct 17, 2024
1 parent ab04fc0 commit 4ee49fb
Showing 2 changed files with 17 additions and 10 deletions.
13 changes: 5 additions & 8 deletions test/distributed/test_c10d_common.py
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
+    get_device_count,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -60,17 +61,13 @@
 torch.backends.cuda.matmul.allow_tf32 = False


-def gpus_for_rank(world_size):
+def gpus_for_rank(world_size, backend):
     """Multigpu tests are designed to simulate the multi nodes with multi
     GPUs on each node. Nccl backend requires equal #GPUs in each process.
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    device_count = (
-        torch.xpu.device_count()
-        if torch.xpu.is_available()
-        else torch.cuda.device_count()
-    )
+    device_count = get_device_count(backend)
     visible_devices = list(range(device_count))
     gpus_per_process = device_count // world_size
     gpus_for_rank = []
@@ -833,7 +830,7 @@ def update_parameters(model):
     def _gpu_model_with_ddp_comm_hook(
         self, process_group, hook=None, gradient_as_bucket_view=False, state=None
     ):
-        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
         gpu_model = DistributedDataParallel(
             ModuleForDdpCommHook().to(device_id),
             device_ids=[device_id],
@@ -850,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook(
     def _gpu_model_with_builtin_ddp_comm_hook(
         self, process_group, hook=None, gradient_as_bucket_view=False
     ):
-        device_id = gpus_for_rank(self.world_size)[self.rank][0]
+        device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0]
         gpu_model = DistributedDataParallel(
             ModuleForDdpCommHook().to(device_id),
             device_ids=[device_id],
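
For reference, the splitting that gpus_for_rank performs once get_device_count has supplied the device count (its loop body falls outside the hunk shown above) is an even partition of the visible devices across ranks. The sketch below is a reconstruction for illustration only; split_devices_across_ranks is a hypothetical name, not part of the commit.

# Reconstruction for illustration: mirrors gpus_for_rank's even split of the
# visible devices across ranks.
def split_devices_across_ranks(device_count: int, world_size: int):
    visible_devices = list(range(device_count))
    gpus_per_process = device_count // world_size
    return [
        visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process]
        for rank in range(world_size)
    ]

# With 8 devices and a world size of 4: rank 0 gets [0, 1], rank 1 gets [2, 3], ...
print(split_devices_across_ranks(8, 4))
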
14 changes: 12 additions & 2 deletions torch/testing/_internal/common_distributed.py
@@ -93,8 +93,9 @@ class DistTestCases:

     # Sets showing that something is implemented
     backend_feature = {}
-    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"}
     backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
backend_feature["cuda"] = {"xccl"}
backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
backend_feature["plugin"] = set()
@@ -462,6 +463,15 @@ def compute_sum(fn, world_size: int):
 ]
 ]

+# Returns the number of GPUs, currently only for CUDA and XPU.
+def get_device_count(backend: str):
+    assert c10d.is_backend_available(backend)
+    if backend in backend_feature.get("cuda", set()):
+        return torch.cuda.device_count()
+    elif backend in backend_feature.get("xpu", set()):
+        return torch.xpu.device_count()
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")

 # HELPER FOR MULTIGPU TESTS
 def init_multigpu_helper(world_size: int, backend: str):
@@ -470,7 +480,7 @@ def init_multigpu_helper(world_size: int, backend: str):
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    nGPUs = torch.xpu.device_count() if torch.xpu.is_available() else torch.cuda.device_count()
+    nGPUs = get_device_count(backend)
     visible_devices = range(nGPUs)

     # If rank is less than or equal to number of available GPU's
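
To make the new dispatch concrete, here is a small self-contained sketch of the same idea. The set contents mirror the backend_feature entries above, but _CUDA_BACKENDS, _XPU_BACKENDS, and device_count_for_backend are illustrative names, the c10d availability assert is omitted, and torch.xpu is assumed to exist (PyTorch 2.4+).

import torch

# Illustrative sketch of the backend-to-device-count dispatch this commit adds;
# names are hypothetical and the c10d availability check is left out.
_CUDA_BACKENDS = {"nccl", "gloo", "ucc"}  # mirrors backend_feature["cuda"]
_XPU_BACKENDS = {"xccl"}                  # mirrors backend_feature["xpu"]

def device_count_for_backend(backend: str) -> int:
    """Map a c10d backend name to the matching accelerator count."""
    if backend in _CUDA_BACKENDS:
        return torch.cuda.device_count()   # 0 on a CPU-only build
    if backend in _XPU_BACKENDS:
        return torch.xpu.device_count()    # requires a build with XPU support
    raise ValueError(f"Unsupported backend: {backend}")

print(device_count_for_backend("nccl"))
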
