Enable typechecking for torch.futures (#41675)

Summary:
Add typing declarations for torch._C.Future and torch._C._collect_all

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41675

Reviewed By: izdeby

Differential Revision: D22627539

Pulled By: malfet

fbshipit-source-id: 29b87685d65dd24ee2094bae8a84a0fe3787e7f8
This commit is contained in:
Nikita Shulga 2020-07-23 23:04:45 -07:00 committed by Facebook GitHub Bot
parent 750d9dea49
commit c0bfa45f9d
4 changed files with 53 additions and 60 deletions

View File

@ -19,6 +19,7 @@ files =
caffe2,
aten/src/ATen/function_wrapper.py,
test/test_complex.py,
test/test_futures.py,
test/test_torch.py,
test/test_type_hints.py,
test/test_type_info.py
@ -53,9 +54,6 @@ ignore_errors = True
[mypy-torch.functional.*]
ignore_errors = True
[mypy-torch.futures.*]
ignore_errors = True
[mypy-torch.testing._internal.*]
ignore_errors = True

View File

@ -11,19 +11,19 @@ def add_one(fut):
class TestFuture(TestCase):
def test_wait(self):
f = Future()
def test_wait(self) -> None:
f = Future[torch.Tensor]()
f.set_result(torch.ones(2, 2))
self.assertEqual(f.wait(), torch.ones(2, 2))
def test_wait_multi_thread(self):
def test_wait_multi_thread(self) -> None:
def slow_set_future(fut, value):
time.sleep(0.5)
fut.set_result(value)
f = Future()
f = Future[torch.Tensor]()
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
t.start()
@ -31,8 +31,8 @@ class TestFuture(TestCase):
self.assertEqual(f.wait(), torch.ones(2, 2))
t.join()
def test_mark_future_twice(self):
fut = Future()
def test_mark_future_twice(self) -> None:
fut = Future[int]()
fut.set_result(1)
with self.assertRaisesRegex(
RuntimeError,
@ -41,14 +41,14 @@ class TestFuture(TestCase):
fut.set_result(1)
def test_pickle_future(self):
fut = Future()
fut = Future[int]()
errMsg = "Can not pickle torch.futures.Future"
with TemporaryFileName() as fname:
with self.assertRaisesRegex(RuntimeError, errMsg):
torch.save(fut, fname)
def test_then(self):
fut = Future()
fut = Future[torch.Tensor]()
then_fut = fut.then(lambda x: x.wait() + 1)
fut.set_result(torch.ones(2, 2))
@ -56,7 +56,7 @@ class TestFuture(TestCase):
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
def test_chained_then(self):
fut = Future()
fut = Future[torch.Tensor]()
futs = []
last_fut = fut
for _ in range(20):
@ -69,7 +69,7 @@ class TestFuture(TestCase):
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
def _test_error(self, cb, errMsg):
fut = Future()
fut = Future[int]()
then_fut = fut.then(cb)
fut.set_result(5)
@ -99,8 +99,8 @@ class TestFuture(TestCase):
self._test_error(raise_value_error, "Expected error")
def test_collect_all(self):
fut1 = Future()
fut2 = Future()
fut1 = Future[int]()
fut2 = Future[int]()
fut_all = torch.futures.collect_all([fut1, fut2])
def slow_in_thread(fut, value):
@ -118,8 +118,8 @@ class TestFuture(TestCase):
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
def test_wait_all(self):
fut1 = Future()
fut2 = Future()
fut1 = Future[int]()
fut2 = Future[int]()
# No error version
fut1.set_result(1)

View File

@ -118,6 +118,14 @@ class _LegacyVariableBase(object):
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
class Future(object):
def __init__(self) -> None: ...
def wait(self) -> Any: ...
def then(self, callback: Callable) -> Future: ...
def set_result(self, result: Any) -> None: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...

View File

@ -1,10 +1,28 @@
from typing import Generic, TypeVar
from typing import cast, Callable, Generic, List, Type, TypeVar
import torch
from torch._six import PY37
T = TypeVar("T")
S = TypeVar("S")
class _PyFuture(torch._C.Future):
def wait(self):
if not PY37:
# Workaround for https://github.com/python/typing/issues/449 in Python 3.6
from typing import GenericMeta
class _PyFutureMeta(type(torch._C.Future), GenericMeta): # type: ignore[misc]
pass
else:
class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef]
pass
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
def wait(self) -> T:
r"""
Block until the value of this ``Future`` is ready.
@ -15,7 +33,8 @@ class _PyFuture(torch._C.Future):
"""
return super().wait()
def then(self, callback):
# Have to use string annotations because PEP-0563 is not available in 3.6
def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S]
r"""
Append the given callback function to this ``Future``, which will be run
when the ``Future`` is completed. Multiple callbacks can be added to
@ -52,9 +71,9 @@ class _PyFuture(torch._C.Future):
>>> # RPC return value is 5.
>>> # Chained cb done. None
"""
return super().then(callback)
return cast(Future[S], super().then(callback))
def set_result(self, result):
def set_result(self, result: T) -> None:
r"""
Set the result for this ``Future``, which will mark this ``Future`` as
completed and trigger all attached callbacks. Note that a ``Future``
@ -85,7 +104,7 @@ class _PyFuture(torch._C.Future):
super().set_result(result)
def collect_all(futures):
def collect_all(futures: List[Future]) -> Future[List[Future]]:
r"""
Collects the provided :class:`~torch.futures.Future` objects into a single
combined :class:`~torch.futures.Future` that is completed when all of the
@ -116,10 +135,10 @@ def collect_all(futures):
>>> # fut0 result = 0
>>> # fut1 result = 1
"""
return torch._C._collect_all(futures)
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
def wait_all(futures):
def wait_all(futures: List[Future]) -> List:
r"""
Waits for all provided futures to be complete, and returns
the list of completed values.
@ -132,36 +151,4 @@ def wait_all(futures):
method will throw an error if ``wait`` on any
:class:`~torch.futures.Future` throws.
"""
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
try:
# Combine the implementation class and the type class.
class Future(_PyFuture, GenericWithOneTypeVar):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
pass
except TypeError as exc:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
class FutureMeta(_PyFuture.__class__, GenericWithOneTypeVar.__class__):
pass
# Combine the implementation class and the type class.
class Future(_PyFuture, GenericWithOneTypeVar, metaclass=FutureMeta):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
pass
return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]