mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add tests for CUDAFuture (#56518)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56518 I don't think we have any tests for CUDAFuture (I couldn't find any, and I didn't write any in the past). I think especially for the two latest features added by this stack we should have a test to ensure they properly work and to catch regressions. (These tests also add indirect coverage for the more "basic" features of CUDAFuture). I didn't know how/where to add tests for C++ ATen stuff, so instead I added these tests to the Python RPC suite, using the torch.futures.Future wrapper. (It made sense in my mind because RPC is the main user of CUDAFuture). I'll gladly accept pointers to better ways of doing this. ghstack-source-id: 127295022 Test Plan: The tests themselves. Reviewed By: mrshenli Differential Revision: D27887191 fbshipit-source-id: 4ad6d81e676fe486aa8d329591ee1a3818fea059
This commit is contained in:
parent
a688b29750
commit
c416167fb7
|
|
@ -29,6 +29,7 @@ from torch.distributed.rpc.internal import (
|
|||
_internal_rpc_pickler,
|
||||
_build_rpc_profiling_key,
|
||||
)
|
||||
from torch.futures import Future
|
||||
from torch.testing._internal.common_distributed import (
|
||||
skip_if_lt_x_gpu,
|
||||
captured_output,
|
||||
|
|
@ -487,6 +488,32 @@ def async_add_multi_fanout(to, x, num, step):
|
|||
return ret_future
|
||||
|
||||
|
||||
# A custom Python class that contains a tensor, needed to see if we correctly
|
||||
# use the Python pickler to extract tensors from non-IValue-convertible types.
|
||||
class TensorWrapper:
|
||||
__slots__ = ("tensor",)
|
||||
|
||||
def __init__(self, t):
|
||||
self.tensor = t
|
||||
|
||||
|
||||
# Copied from test/test_cuda.py.
|
||||
_cycles_per_ms = None
|
||||
|
||||
def get_cycles_per_ms():
|
||||
"""Approximate number of cycles per millisecond for torch.cuda._sleep"""
|
||||
global _cycles_per_ms
|
||||
if _cycles_per_ms is None:
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
torch.cuda._sleep(1000000)
|
||||
end.record()
|
||||
end.synchronize()
|
||||
_cycles_per_ms = 1000000 / start.elapsed_time(end)
|
||||
return _cycles_per_ms
|
||||
|
||||
|
||||
class AsyncExecutionClass:
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -5766,3 +5793,111 @@ class TensorPipeAgentCudaRpcTest(RpcAgentTestFixture):
|
|||
)
|
||||
|
||||
rpc.shutdown()
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_device_as_int(self):
|
||||
fut = Future(devices=[0])
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_device_as_str(self):
|
||||
fut = Future(devices=["cuda:0"])
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_device_as_device(self):
|
||||
fut = Future(devices=[torch.device("cuda", 0)])
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_device_not_cuda(self):
|
||||
with self.assertRaisesRegex(ValueError, "Expected CUDA devices, got "):
|
||||
fut = Future(devices=["cpu"])
|
||||
|
||||
def _test_cuda_future_extraction(self, wrapper, unwrapper):
|
||||
# We check proper CUDA stream synchronization by filling the tensor with
|
||||
# the expected value in one stream, and reading it from another stream.
|
||||
tensor = torch.zeros((100,), device="cuda:0")
|
||||
future = Future(devices=["cuda:0"])
|
||||
with torch.cuda.device("cuda:0"):
|
||||
stream = torch.cuda.Stream()
|
||||
another_stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
|
||||
tensor.fill_(1)
|
||||
future.set_result(wrapper(tensor))
|
||||
with torch.cuda.stream(another_stream):
|
||||
self.assertTrue(torch.eq(unwrapper(future.wait()), 1).all().item())
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_can_extract_cuda_tensor(self):
|
||||
self._test_cuda_future_extraction(
|
||||
wrapper=lambda t: t, unwrapper=lambda v: v
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_can_extract_list_with_cuda_tensor(self):
|
||||
self._test_cuda_future_extraction(
|
||||
wrapper=lambda t: [t], unwrapper=lambda v: v[0]
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
def test_cuda_future_can_extract_custom_class_with_cuda_tensor(self):
|
||||
self._test_cuda_future_extraction(
|
||||
wrapper=lambda t: TensorWrapper(t), unwrapper=lambda v: v.tensor
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cuda_future_callback_changes_devices(self):
|
||||
# We check proper CUDA stream synchronization by filling the tensor with
|
||||
# the expected value in one stream, and reading it from another stream.
|
||||
tensor0 = torch.zeros((100,), device="cuda:0")
|
||||
tensor1 = torch.zeros((100,), device="cuda:1")
|
||||
parent_future = Future(devices=["cuda:0", "cuda:1"])
|
||||
|
||||
def cb(fut):
|
||||
t0 = fut.value()
|
||||
tensor1.copy_(t0, non_blocking=True)
|
||||
return tensor1
|
||||
|
||||
child_future = parent_future.then(cb)
|
||||
with torch.cuda.device("cuda:0"):
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
|
||||
tensor0.fill_(1)
|
||||
parent_future.set_result(tensor0)
|
||||
with torch.cuda.device("cuda:1"):
|
||||
another_stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(another_stream):
|
||||
self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_cuda_future_value_on_bad_device(self):
|
||||
tensor0 = torch.zeros((100,), device="cuda:0")
|
||||
tensor1 = torch.zeros((100,), device="cuda:1")
|
||||
parent_future = Future(devices=["cuda:1"])
|
||||
|
||||
# As a plus, we test that futures still invoke callbacks even in case of
|
||||
# error, and that the child futures are successful if those callbacks
|
||||
# don't access the parent future.
|
||||
def cb(fut):
|
||||
with torch.cuda.device("cuda:1"):
|
||||
torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
|
||||
tensor1.fill_(1)
|
||||
return tensor1
|
||||
|
||||
child_future = parent_future.then(cb)
|
||||
with torch.cuda.device("cuda:0"):
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
torch.cuda._sleep(int(1000 * get_cycles_per_ms()))
|
||||
tensor0.fill_(1)
|
||||
parent_future.set_result(tensor0)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"The result contained tensors residing on device\(s\) cuda:0 "
|
||||
r"which are not among the expected device\(s\) cuda:1",
|
||||
):
|
||||
parent_future.wait()
|
||||
with torch.cuda.device("cuda:1"):
|
||||
another_stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(another_stream):
|
||||
self.assertTrue(torch.eq(child_future.wait(), 1).all().item())
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user