Skip to content

Commit

Permalink
[RELAND] Add UTs for accelerator device-agnostic runtime APIs (pytorc…
Browse files Browse the repository at this point in the history
…h#133572)

# Motivation
This PR intends to add UTs for accelerator device-agnostic APIs.

# Additional Context
This PR is relanded. It is reverted because `torch.Event` doesn't support mps backend. We have fixed it in pytorch#142468. The previous commit is pytorch@952514f

Pull Request resolved: pytorch#133572
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: pytorch#143171
  • Loading branch information
guangyey authored and pytorchmergebot committed Dec 16, 2024
1 parent c1d4d9d commit 45ac4eb
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
73 changes: 73 additions & 0 deletions test/test_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Owner(s): ["module: tests"]

import sys
import unittest

import torch
from torch.testing._internal.common_utils import NoTest, run_tests, TestCase


if not torch.accelerator.is_available():
print("No available accelerator detected, skipping tests", file=sys.stderr)
TestCase = NoTest # noqa: F811

TEST_MULTIACCELERATOR = torch.accelerator.device_count() > 1


class TestAccelerator(TestCase):
def test_current_accelerator(self):
self.assertTrue(torch.accelerator.is_available())
accelerators = ["cuda", "xpu", "mps"]
for accelerator in accelerators:
if torch.get_device_module(accelerator).is_available():
self.assertEqual(
torch.accelerator.current_accelerator().type, accelerator
)
self.assertIsNone(torch.accelerator.current_accelerator().index)
with self.assertRaisesRegex(
ValueError, "doesn't match the current accelerator"
):
torch.accelerator.set_device_index("cpu")

@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
def test_generic_multi_device_behavior(self):
orig_device = torch.accelerator.current_device_index()
target_device = (orig_device + 1) % torch.accelerator.device_count()

torch.accelerator.set_device_index(target_device)
self.assertEqual(target_device, torch.accelerator.current_device_index())
torch.accelerator.set_device_index(orig_device)
self.assertEqual(orig_device, torch.accelerator.current_device_index())

s1 = torch.Stream(target_device)
torch.accelerator.set_stream(s1)
self.assertEqual(target_device, torch.accelerator.current_device_index())
torch.accelerator.synchronize(orig_device)
self.assertEqual(target_device, torch.accelerator.current_device_index())

def test_generic_stream_behavior(self):
s1 = torch.Stream()
s2 = torch.Stream()
torch.accelerator.set_stream(s1)
self.assertEqual(torch.accelerator.current_stream(), s1)
event = torch.Event()
a = torch.randn(1000)
b = torch.randn(1000)
c = a + b
torch.accelerator.set_stream(s2)
self.assertEqual(torch.accelerator.current_stream(), s2)
a_acc = a.to(torch.accelerator.current_accelerator(), non_blocking=True)
b_acc = b.to(torch.accelerator.current_accelerator(), non_blocking=True)
torch.accelerator.set_stream(s1)
self.assertEqual(torch.accelerator.current_stream(), s1)
event.record(s2)
event.synchronize()
c_acc = a_acc + b_acc
event.record(s2)
torch.accelerator.synchronize()
self.assertTrue(event.query())
self.assertEqual(c_acc.cpu(), c)


if __name__ == "__main__":
run_tests()
8 changes: 8 additions & 0 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,14 @@ def test_generic_stream_event(self):
self.assertTrue(issubclass(type(cuda_event), torch.Event))
self.assertTrue(torch.Event in type(cuda_event).mro())

def test_stream_compatibility(self):
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
torch.accelerator.set_stream(s1)
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
torch.accelerator.set_stream(s2)
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)

def test_record_stream(self):
cycles_per_ms = get_cycles_per_ms()

Expand Down
8 changes: 8 additions & 0 deletions test/test_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,14 @@ def test_generic_stream_event(self):
self.assertTrue(issubclass(type(xpu_event), torch.Event))
self.assertTrue(torch.Event in type(xpu_event).mro())

def test_stream_compatibility(self):
s1 = torch.xpu.Stream()
s2 = torch.xpu.Stream()
torch.accelerator.set_stream(s1)
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
torch.accelerator.set_stream(s2)
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)

def test_generator(self):
torch.manual_seed(2024)
g_state0 = torch.xpu.get_rng_state()
Expand Down

0 comments on commit 45ac4eb

Please sign in to comment.