diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 77447d03f94..2578f20f24a 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -29,6 +29,7 @@ from torch.distributed.rpc.internal import (
     _internal_rpc_pickler,
     _build_rpc_profiling_key,
 )
+from torch.futures import Future
 from torch.testing._internal.common_distributed import (
     skip_if_lt_x_gpu,
     captured_output,
@@ -487,6 +488,32 @@ def async_add_multi_fanout(to, x, num, step):
     return ret_future
 
 
+# A custom Python class that contains a tensor, needed to see if we correctly
+# use the Python pickler to extract tensors from non-IValue-convertible types.
+class TensorWrapper:
+    __slots__ = ("tensor",)
+
+    def __init__(self, t):
+        self.tensor = t
+
+
+# Copied from test/test_cuda.py.
+_cycles_per_ms = None
+
+def get_cycles_per_ms():
+    """Approximate number of cycles per millisecond for torch.cuda._sleep"""
+    global _cycles_per_ms
+    if _cycles_per_ms is None:
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        torch.cuda._sleep(1000000)
+        end.record()
+        end.synchronize()
+        _cycles_per_ms = 1000000 / start.elapsed_time(end)
+    return _cycles_per_ms
+
+
 class AsyncExecutionClass:
 
     @staticmethod
@@ -5766,3 +5793,111 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture):
         )
 
         rpc.shutdown()
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_int(self):
+        fut = Future(devices=[0])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_str(self):
+        fut = Future(devices=["cuda:0"])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_as_device(self):
+        fut = Future(devices=[torch.device("cuda", 0)])
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_device_not_cuda(self):
+        with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "):
+            fut = Future(devices=["cpu"])
+
+    def _test_cuda_future_extraction(self, wrapper, unwrapper):
+        # We check proper CUDA stream synchronization by filling the tensor with
+        # the expected value in one stream, and reading it from another stream.
+        tensor = torch.zeros((100,), device="cuda:0")
+        future = Future(devices=["cuda:0"])
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor.fill_(1)
+                future.set_result(wrapper(tensor))
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(unwrapper(future.wait()), 1).all().item())
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: t, unwrapper=lambda v: v
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_list_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: [t], unwrapper=lambda v: v[0]
+        )
+
+    @skip_if_lt_x_gpu(1)
+    def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
+        self._test_cuda_future_extraction(
+            wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor
+        )
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_callback_changes_devices(self):
+        # We check proper CUDA stream synchronization by filling the tensor with
+        # the expected value in one stream, and reading it from another stream.
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:0", "cuda:1"])
+
+        def cb(fut):
+            t0 = fut.value()
+            tensor1.copy_(t0, non_blocking=True)
+            return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
+
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_future_value_on_bad_device(self):
+        tensor0 = torch.zeros((100,), device="cuda:0")
+        tensor1 = torch.zeros((100,), device="cuda:1")
+        parent_future = Future(devices=["cuda:1"])
+
+        # As a plus, we test that futures still invoke callbacks even in case of
+        # error, and that the child futures are successful if those callbacks
+        # don't access the parent future.
+        def cb(fut):
+            with torch.cuda.device("cuda:1"):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor1.fill_(1)
+                return tensor1
+
+        child_future = parent_future.then(cb)
+        with torch.cuda.device("cuda:0"):
+            stream = torch.cuda.Stream()
+            with torch.cuda.stream(stream):
+                torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
+                tensor0.fill_(1)
+                parent_future.set_result(tensor0)
+        with self.assertRaisesRegex(
+            ValueError,
+            r"The result contained tensors residing on device\(s\) cuda:0 "
+            r"which are not among the expected device\(s\) cuda:1",
+        ):
+            parent_future.wait()
+        with torch.cuda.device("cuda:1"):
+            another_stream = torch.cuda.Stream()
+            with torch.cuda.stream(another_stream):
+                self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
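Note (not part of the patch): a minimal sketch of the CUDA-aware Future API these tests exercise, using only the calls that appear in the diff (Future(devices=...), set_result, then, value, wait). It assumes a single CUDA device is available; the stream synchronization the comments below describe is exactly the behavior the tests above verify.

    import torch
    from torch.futures import Future

    # Declare up front which devices the future's value may reside on;
    # a value on any other device fails the wait() (see
    # test_cuda_future_value_on_bad_device above).
    fut = Future(devices=["cuda:0"])

    # Produce the result on a side stream; set_result() captures the CUDA
    # state consumers need in order to synchronize with the producer.
    with torch.cuda.stream(torch.cuda.Stream()):
        fut.set_result(torch.ones((100,), device="cuda:0"))

    # then() chains a callback; the child future completes with its return
    # value, which may live on a different device than the parent's.
    child = fut.then(lambda f: f.value() * 2)

    # wait() makes the caller's current stream wait on the producer's work,
    # so the value is safe to use here without an explicit synchronize().
    print(child.wait().sum().item())  # prints 200.0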