mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40233 There was a question earlier whether torch.futures.wait_all() would raise if the underlying futures raise (it was supposed to, but there was no test coverage). This change adds a couple of very basic torch.futures.collect_all/wait_all tests. ghstack-source-id: 106168134 Test Plan: buck test mode/dev-nosan caffe2/test:futures Differential Revision: D22120284 fbshipit-source-id: 3a8edae5dbf8c58c8361eff156c386a684ec5e86
135 lines
3.5 KiB
Python
135 lines
3.5 KiB
Python
import threading
|
|
import time
|
|
import torch
|
|
from torch.futures import Future
|
|
from torch.testing._internal.common_utils import TestCase, TemporaryFileName
|
|
|
|
|
|
def add_one(fut):
    """Wait on *fut* and return its result incremented by one.

    Used as a ``Future.then`` callback, so *fut* is the completed parent
    future; calling ``wait()`` on it retrieves that parent's value.
    """
    value = fut.wait()
    return value + 1
|
|
|
|
|
|
class TestFuture(TestCase):
|
|
def test_wait(self):
|
|
f = Future()
|
|
f.set_result(torch.ones(2, 2))
|
|
|
|
self.assertEqual(f.wait(), torch.ones(2, 2))
|
|
|
|
def test_wait_multi_thread(self):
|
|
|
|
def slow_set_future(fut, value):
|
|
time.sleep(0.5)
|
|
fut.set_result(value)
|
|
|
|
f = Future()
|
|
|
|
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
|
|
t.start()
|
|
|
|
self.assertEqual(f.wait(), torch.ones(2, 2))
|
|
t.join()
|
|
|
|
def test_mark_future_twice(self):
|
|
fut = Future()
|
|
fut.set_result(1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Future can only be marked completed once"
|
|
):
|
|
fut.set_result(1)
|
|
|
|
def test_pickle_future(self):
|
|
fut = Future()
|
|
errMsg = "Can not pickle torch.futures.Future"
|
|
with TemporaryFileName() as fname:
|
|
with self.assertRaisesRegex(RuntimeError, errMsg):
|
|
torch.save(fut, fname)
|
|
|
|
def test_then(self):
|
|
fut = Future()
|
|
then_fut = fut.then(lambda x: x.wait() + 1)
|
|
|
|
fut.set_result(torch.ones(2, 2))
|
|
self.assertEqual(fut.wait(), torch.ones(2, 2))
|
|
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
|
|
|
|
def test_chained_then(self):
|
|
fut = Future()
|
|
futs = []
|
|
last_fut = fut
|
|
for _ in range(20):
|
|
last_fut = last_fut.then(add_one)
|
|
futs.append(last_fut)
|
|
|
|
fut.set_result(torch.ones(2, 2))
|
|
|
|
for i in range(len(futs)):
|
|
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
|
|
|
|
def _test_error(self, cb, errMsg):
|
|
fut = Future()
|
|
then_fut = fut.then(cb)
|
|
|
|
fut.set_result(5)
|
|
self.assertEqual(5, fut.wait())
|
|
with self.assertRaisesRegex(RuntimeError, errMsg):
|
|
then_fut.wait()
|
|
|
|
def test_then_wrong_arg(self):
|
|
|
|
def wrong_arg(tensor):
|
|
return tensor + 1
|
|
|
|
self._test_error(wrong_arg, "unsupported operand type.*Future.*int")
|
|
|
|
def test_then_no_arg(self):
|
|
|
|
def no_arg():
|
|
return True
|
|
|
|
self._test_error(no_arg, "takes 0 positional arguments but 1 was given")
|
|
|
|
def test_then_raise(self):
|
|
|
|
def raise_value_error(fut):
|
|
raise ValueError("Expected error")
|
|
|
|
self._test_error(raise_value_error, "Expected error")
|
|
|
|
def test_collect_all(self):
|
|
fut1 = Future()
|
|
fut2 = Future()
|
|
fut_all = torch.futures.collect_all([fut1, fut2])
|
|
|
|
def slow_in_thread(fut, value):
|
|
time.sleep(0.1)
|
|
fut.set_result(value)
|
|
|
|
t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
|
|
fut2.set_result(2)
|
|
t.start()
|
|
|
|
res = fut_all.wait()
|
|
self.assertEqual(res[0].wait(), 1)
|
|
self.assertEqual(res[1].wait(), 2)
|
|
t.join()
|
|
|
|
def test_wait_all(self):
|
|
fut1 = Future()
|
|
fut2 = Future()
|
|
|
|
# No error version
|
|
fut1.set_result(1)
|
|
fut2.set_result(2)
|
|
res = torch.futures.wait_all([fut1, fut2])
|
|
print(res)
|
|
self.assertEqual(res, [1, 2])
|
|
|
|
# Version with an exception
|
|
def raise_in_fut(fut):
|
|
raise ValueError("Expected error")
|
|
fut3 = fut1.then(raise_in_fut)
|
|
with self.assertRaisesRegex(RuntimeError, "Expected error"):
|
|
torch.futures.wait_all([fut3, fut2])
|