Add done() API to Future (#42013)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42013

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D22729596

Pulled By: mrshenli

fbshipit-source-id: ed31021a35af6e2c3393b9b14e4572cf51013bc0
This commit is contained in:
Shen Li 2020-07-24 14:10:33 -07:00 committed by Facebook GitHub Bot
parent 890b52e09f
commit d4736ef95f
6 changed files with 56 additions and 0 deletions

View File

@ -11,6 +11,30 @@ def add_one(fut):
class TestFuture(TestCase):
def test_done(self) -> None:
f = Future[torch.Tensor]()
self.assertFalse(f.done())
f.set_result(torch.ones(2, 2))
self.assertTrue(f.done())
def test_done_exception(self) -> None:
err_msg = "Intentional Value Error"
def raise_exception(unused_future):
raise RuntimeError(err_msg)
f1 = Future[torch.Tensor]()
self.assertFalse(f1.done())
f1.set_result(torch.ones(2, 2))
self.assertTrue(f1.done())
f2 = f1.then(raise_exception)
self.assertTrue(f2.done())
with self.assertRaisesRegex(RuntimeError, err_msg):
f2.wait()
def test_wait(self) -> None:
f = Future[torch.Tensor]()
f.set_result(torch.ones(2, 2))

View File

@ -120,6 +120,7 @@ class _LegacyVariableBase(object):
# Defined in torch/csrc/jit/python/init.cpp
class Future(object):
def __init__(self) -> None: ...
def done(self) -> _bool: ...
def wait(self) -> Any: ...
def then(self, callback: Callable) -> Future: ...
def set_result(self, result: Any) -> None: ...

View File

@ -942,6 +942,10 @@ void initJITBindings(PyObject* module) {
return std::make_shared<PythonFutureWrapper>(
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
}))
.def(
"done",
// Intentionally not releasing GIL
&PythonFutureWrapper::done)
.def(
"wait",
&PythonFutureWrapper::wait,

View File

@ -74,6 +74,10 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
bool done() {
return fut->completed();
}
py::object wait() {
fut->wait();
if (jit::tracer::isTracing()) {

View File

@ -22,6 +22,14 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
def done(self) -> bool:
r"""
Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
has a result or an exception.
"""
return super().done()
def wait(self) -> T:
r"""
Block until the value of this ``Future`` is ready.

View File

@ -3105,6 +3105,21 @@ class RpcTest(RpcAgentTestFixture):
with self.assertRaisesRegex(RuntimeError, errMsg):
rpc.remote(dst, fail_on_fut, args=(fut,))
@dist_init
def test_future_done(self):
dst = worker_name((self.rank + 1) % self.world_size)
fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
fut.wait()
self.assertTrue(fut.done())
@dist_init
def test_future_done_exception(self):
dst = worker_name((self.rank + 1) % self.world_size)
fut = rpc.rpc_async(dst, raise_func)
with self.assertRaisesRegex(ValueError, "Expected error"):
fut.wait()
self.assertTrue(fut.done())
def _test_future_cb(self, func):
dst1 = worker_name((self.rank + 1) % self.world_size)
dst2 = worker_name((self.rank + 2) % self.world_size)