Skip to content

Commit

Permalink
fix number devices
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Mar 12, 2024
1 parent 7ade74f commit 0089a7c
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
LAUNCHER_CONFIGS = [
InlineConfig(device_isolation=False),
ProcessConfig(device_isolation=False),
TorchrunConfig(device_isolation=False, nproc_per_node=2),
TorchrunConfig(device_isolation=False, nproc_per_node=4),
]
BACKENDS = ["pytorch", "none"]
DEVICES = ["cpu", "cuda"]
Expand Down Expand Up @@ -113,10 +113,12 @@ def test_api_memory_tracker(device, backend):
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("launcher_config", LAUNCHER_CONFIGS)
def test_api_launch(device, launcher_config):
device_ids = "0,1,2,3" if device == "cuda" else None

benchmark_config = InferenceConfig(latency=True, memory=True)
backend_config = PyTorchConfig(
model="bert-base-uncased",
device_ids="0,1" if device == "cuda" else None,
device_ids=device_ids,
no_weights=True,
device=device,
)
Expand Down

0 comments on commit 0089a7c

Please sign in to comment.