From c416167fb7d93f089133e340cd21ed8ef1b95992 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Sat, 24 Apr 2021 07:05:09 -0700 Subject: [PATCH] Add tests for CUDAFuture (#56518) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56518 I don't think we have any tests for CUDAFuture (I couldn't find any, and I didn't write any in the past). I think especially for the two latest features added by this stack we should have a test to ensure they properly work and to catch regressions. (These tests also add indirect coverage for the more "basic" features of CUDAFuture). I didn't know how/where to add tests for C++ ATen stuff, so instead I added these tests to the Python RPC suite, using the torch.futures.Future wrapper. (It made sense in my mind because RPC is the main user of CUDAFuture). I'll gladly accept pointers to better ways of doing this. ghstack-source-id: 127295022 Test Plan: The tests themselves. Reviewed By: mrshenli Differential Revision: D27887191 fbshipit-source-id: 4ad6d81e676fe486aa8d329591ee1a3818fea059 --- .../_internal/distributed/rpc/rpc_test.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 77447d03f941e..2578f20f24a59 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -29,6 +29,7 @@ _internal_rpc_pickler, _build_rpc_profiling_key, ) +from torch.futures import Future from torch.testing._internal.common_distributed import ( skip_if_lt_x_gpu, captured_output, @@ -487,6 +488,32 @@ def inc_and_set(fut): return ret_future +# A custom Python class that contains a tensor, needed to see if we correctly +# use the Python pickler to extract tensors from non-IValue-convertible types. +class TensorWrapper: + __slots__ = ("tensor",) + + def __init__(self, t): + self.tensor = t + + +# Copied from test/test_cuda.py. +_cycles_per_ms = None + +def get_cycles_per_ms(): + """Approximate number of cycles per millisecond for torch.cuda._sleep""" + global _cycles_per_ms + if _cycles_per_ms is None: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + torch.cuda._sleep(1000000) + end.record() + end.synchronize() + _cycles_per_ms = 1000000 / start.elapsed_time(end) + return _cycles_per_ms + + class AsyncExecutionClass: @staticmethod @@ -5766,3 +5793,111 @@ def test_devices_option_mismatch_reverse(self): ) rpc.shutdown() + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_int(self): + fut = Future(devices=[0]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_str(self): + fut = Future(devices=["cuda:0"]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_as_device(self): + fut = Future(devices=[torch.device("cuda", 0)]) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_device_not_cuda(self): + with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "): + fut = Future(devices=["cpu"]) + + def _test_cuda_future_extraction(self, wrapper, unwrapper): + # We check proper CUDA stream synchronization by filling the tensor with + # the expected value in one stream, and reading it from another stream. + tensor = torch.zeros((100,), device="cuda:0") + future = Future(devices=["cuda:0"]) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + another_stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor.fill_(1) + future.set_result(wrapper(tensor)) + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(unwrapper(future.wait()), 1).all().item()) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: t, unwrapper=lambda v: v + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_list_with_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: [t], unwrapper=lambda v: v[0] + ) + + @skip_if_lt_x_gpu(1) + def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self): + self._test_cuda_future_extraction( + wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor + ) + + @skip_if_lt_x_gpu(2) + def test_cuda_future_callback_changes_devices(self): + # We check proper CUDA stream synchronization by filling the tensor with + # the expected value in one stream, and reading it from another stream. + tensor0 = torch.zeros((100,), device="cuda:0") + tensor1 = torch.zeros((100,), device="cuda:1") + parent_future = Future(devices=["cuda:0", "cuda:1"]) + + def cb(fut): + t0 = fut.value() + tensor1.copy_(t0, non_blocking=True) + return tensor1 + + child_future = parent_future.then(cb) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor0.fill_(1) + parent_future.set_result(tensor0) + with torch.cuda.device("cuda:1"): + another_stream = torch.cuda.Stream() + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(child_future.wait(), 1).all().item()) + + @skip_if_lt_x_gpu(2) + def test_cuda_future_value_on_bad_device(self): + tensor0 = torch.zeros((100,), device="cuda:0") + tensor1 = torch.zeros((100,), device="cuda:1") + parent_future = Future(devices=["cuda:1"]) + + # As a plus, we test that futures still invoke callbacks even in case of + # error, and that the child futures are successful if those callbacks + # don't access the parent future. + def cb(fut): + with torch.cuda.device("cuda:1"): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor1.fill_(1) + return tensor1 + + child_future = parent_future.then(cb) + with torch.cuda.device("cuda:0"): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + torch.cuda._sleep(int(1000 * get_cycles_per_ms())) + tensor0.fill_(1) + parent_future.set_result(tensor0) + with self.assertRaisesRegex( + ValueError, + r"The result contained tensors residing on device\(s\) cuda:0 " + r"which are not among the expected device\(s\) cuda:1", + ): + parent_future.wait() + with torch.cuda.device("cuda:1"): + another_stream = torch.cuda.Stream() + with torch.cuda.stream(another_stream): + self.assertTrue(torch.eq(child_future.wait(), 1).all().item())