mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add done() API to Future (#42013)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42013 Test Plan: Imported from OSS Reviewed By: rohan-varma Differential Revision: D22729596 Pulled By: mrshenli fbshipit-source-id: ed31021a35af6e2c3393b9b14e4572cf51013bc0
This commit is contained in:
parent
890b52e09f
commit
d4736ef95f
|
|
@ -11,6 +11,30 @@ def add_one(fut):
|
|||
|
||||
|
||||
class TestFuture(TestCase):
|
||||
|
||||
def test_done(self) -> None:
|
||||
f = Future[torch.Tensor]()
|
||||
self.assertFalse(f.done())
|
||||
|
||||
f.set_result(torch.ones(2, 2))
|
||||
self.assertTrue(f.done())
|
||||
|
||||
def test_done_exception(self) -> None:
|
||||
err_msg = "Intentional Value Error"
|
||||
|
||||
def raise_exception(unused_future):
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
f1 = Future[torch.Tensor]()
|
||||
self.assertFalse(f1.done())
|
||||
f1.set_result(torch.ones(2, 2))
|
||||
self.assertTrue(f1.done())
|
||||
|
||||
f2 = f1.then(raise_exception)
|
||||
self.assertTrue(f2.done())
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
f2.wait()
|
||||
|
||||
def test_wait(self) -> None:
|
||||
f = Future[torch.Tensor]()
|
||||
f.set_result(torch.ones(2, 2))
|
||||
|
|
|
|||
|
|
@ -120,6 +120,7 @@ class _LegacyVariableBase(object):
|
|||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
class Future(object):
|
||||
def __init__(self) -> None: ...
|
||||
def done(self) -> _bool: ...
|
||||
def wait(self) -> Any: ...
|
||||
def then(self, callback: Callable) -> Future: ...
|
||||
def set_result(self, result: Any) -> None: ...
|
||||
|
|
|
|||
|
|
@ -942,6 +942,10 @@ void initJITBindings(PyObject* module) {
|
|||
return std::make_shared<PythonFutureWrapper>(
|
||||
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
|
||||
}))
|
||||
.def(
|
||||
"done",
|
||||
// Intentionally not releasing GIL
|
||||
&PythonFutureWrapper::done)
|
||||
.def(
|
||||
"wait",
|
||||
&PythonFutureWrapper::wait,
|
||||
|
|
|
|||
|
|
@ -74,6 +74,10 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
|
|||
explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
|
||||
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
|
||||
|
||||
bool done() {
|
||||
return fut->completed();
|
||||
}
|
||||
|
||||
py::object wait() {
|
||||
fut->wait();
|
||||
if (jit::tracer::isTracing()) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
|||
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
|
||||
also exposes a set of APIs to add callback functions and set results.
|
||||
"""
|
||||
|
||||
def done(self) -> bool:
|
||||
r"""
|
||||
Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
|
||||
has a result or an exception.
|
||||
"""
|
||||
return super().done()
|
||||
|
||||
def wait(self) -> T:
|
||||
r"""
|
||||
Block until the value of this ``Future`` is ready.
|
||||
|
|
|
|||
|
|
@ -3105,6 +3105,21 @@ class RpcTest(RpcAgentTestFixture):
|
|||
with self.assertRaisesRegex(RuntimeError, errMsg):
|
||||
rpc.remote(dst, fail_on_fut, args=(fut,))
|
||||
|
||||
@dist_init
|
||||
def test_future_done(self):
|
||||
dst = worker_name((self.rank + 1) % self.world_size)
|
||||
fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
|
||||
fut.wait()
|
||||
self.assertTrue(fut.done())
|
||||
|
||||
@dist_init
|
||||
def test_future_done_exception(self):
|
||||
dst = worker_name((self.rank + 1) % self.world_size)
|
||||
fut = rpc.rpc_async(dst, raise_func)
|
||||
with self.assertRaisesRegex(ValueError, "Expected error"):
|
||||
fut.wait()
|
||||
self.assertTrue(fut.done())
|
||||
|
||||
def _test_future_cb(self, func):
|
||||
dst1 = worker_name((self.rank + 1) % self.world_size)
|
||||
dst2 = worker_name((self.rank + 2) % self.world_size)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user