Add tests for CUDAFuture (pytorch#56518)

Summary:
Pull Request resolved: pytorch#56518

I don't think we have any tests for CUDAFuture (I couldn't find any, and I didn't write any in the past). Especially for the two latest features added by this stack, we should have tests to ensure they work properly and to catch regressions. (These tests also add indirect coverage for the more "basic" features of CUDAFuture.)

I didn't know how or where to add tests for C++ ATen stuff, so instead I added these tests to the Python RPC suite, using the torch.futures.Future wrapper. (It made sense in my mind because RPC is the main user of CUDAFuture.) I'll gladly accept pointers to better ways of doing this.
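
For reference, a minimal sketch of the pattern these tests exercise, assuming a
machine with one CUDA device (torch.futures.Future(devices=...) is the Python
wrapper around CUDAFuture):

    import torch
    from torch.futures import Future

    fut = Future(devices=["cuda:0"])  # future that tracks CUDA work
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        t = torch.full((100,), 1.0, device="cuda:0")  # work on a side stream
        fut.set_result(t)  # the future records CUDA events for the result
    # wait() makes the caller's current stream wait on the recorded events
    assert fut.wait().eq(1).all().item()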
ghstack-source-id: 127295022

Test Plan: The tests themselves.

Reviewed By: mrshenli

Differential Revision: D27887191

fbshipit-source-id: 4ad6d81e676fe486aa8d329591ee1a3818fea059
lw authored and facebook-github-bot committed Apr 24, 2021
1 parent a688b29 commit c416167
Showing 1 changed file with 135 additions and 0 deletions.
torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -29,6 +29,7 @@
    _internal_rpc_pickler,
    _build_rpc_profiling_key,
)
from torch.futures import Future
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
    captured_output,
@@ -487,6 +488,32 @@ def inc_and_set(fut):
    return ret_future


# A custom Python class that contains a tensor, needed to see if we correctly
# use the Python pickler to extract tensors from non-IValue-convertible types.
class TensorWrapper:
    __slots__ = ("tensor",)

    def __init__(self, t):
        self.tensor = t


# Copied from test/test_cuda.py.
_cycles_per_ms = None

def get_cycles_per_ms():
    """Approximate number of cycles per millisecond for torch.cuda._sleep"""
    global _cycles_per_ms
    if _cycles_per_ms is None:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.cuda._sleep(1000000)
        end.record()
        end.synchronize()
        _cycles_per_ms = 1000000 / start.elapsed_time(end)
    return _cycles_per_ms


class AsyncExecutionClass:

    @staticmethod
@@ -5766,3 +5793,111 @@ def test_devices_option_mismatch_reverse(self):
        )

        rpc.shutdown()

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_int(self):
        fut = Future(devices=[0])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_str(self):
        fut = Future(devices=["cuda:0"])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_as_device(self):
        fut = Future(devices=[torch.device("cuda", 0)])

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_device_not_cuda(self):
        with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "):
            fut = Future(devices=["cpu"])

    def _test_cuda_future_extraction(self, wrapper, unwrapper):
        # We check proper CUDA stream synchronization by filling the tensor with
        # the expected value in one stream, and reading it from another stream.
        tensor = torch.zeros((100,), device="cuda:0")
        future = Future(devices=["cuda:0"])
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
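                # Delay the fill so that a missing cross-stream synchronization
                # would be observed as a tensor of zeros below.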
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor.fill_(1)
                future.set_result(wrapper(tensor))
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(unwrapper(future.wait()), 1).all().item())

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: t, unwrapper=lambda v: v
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: [t], unwrapper=lambda v: v[0]
        )

    @skip_if_lt_x_gpu(1)
    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
        self._test_cuda_future_extraction(
            wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor
        )

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_callback_changes_devices(self):
        # We check proper CUDA stream synchronization by filling the tensor with
        # the expected value in one stream, and reading it from another stream.
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:0", "cuda:1"])

        def cb(fut):
            t0 = fut.value()
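            # The cross-device copy is enqueued on the callback's current
            # stream; the child future must synchronize consumers with it.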
            tensor1.copy_(t0, non_blocking=True)
            return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())

    @skip_if_lt_x_gpu(2)
    def test_cuda_future_value_on_bad_device(self):
        tensor0 = torch.zeros((100,), device="cuda:0")
        tensor1 = torch.zeros((100,), device="cuda:1")
        parent_future = Future(devices=["cuda:1"])

        # As a plus, we test that futures still invoke callbacks even in case of
        # error, and that the child futures are successful if those callbacks
        # don't access the parent future.
        def cb(fut):
            with torch.cuda.device("cuda:1"):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor1.fill_(1)
                return tensor1

        child_future = parent_future.then(cb)
        with torch.cuda.device("cuda:0"):
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
                tensor0.fill_(1)
                parent_future.set_result(tensor0)
        with self.assertRaisesRegex(
            ValueError,
            r"The result contained tensors residing on device\(s\) cuda:0 "
            r"which are not among the expected device\(s\) cuda:1",
        ):
            parent_future.wait()
        with torch.cuda.device("cuda:1"):
            another_stream = torch.cuda.Stream()
            with torch.cuda.stream(another_stream):
                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
