Enable typechecking for torch.futures (#41675)

Summary:
Add typing declarations for torch._C.Future and torch._C._collect_all

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41675

Reviewed By: izdeby

Differential Revision: D22627539

Pulled By: malfet

fbshipit-source-id: 29b87685d65dd24ee2094bae8a84a0fe3787e7f8
This commit is contained in:
Nikita Shulga 2020-07-23 23:04:45 -07:00 committed by Facebook GitHub Bot
parent 750d9dea49
commit c0bfa45f9d
4 changed files with 53 additions and 60 deletions

View File

@@ -19,6 +19,7 @@ files =
caffe2,
aten/src/ATen/function_wrapper.py,
test/test_complex.py,
test/test_futures.py,
test/test_torch.py,
test/test_type_hints.py,
test/test_type_info.py
@@ -53,9 +54,6 @@ ignore_errors = True
[mypy-torch.functional.*]
ignore_errors = True
[mypy-torch.futures.*]
ignore_errors = True
[mypy-torch.testing._internal.*]
ignore_errors = True

View File

@@ -11,19 +11,19 @@ def add_one(fut):
class TestFuture(TestCase):
def test_wait(self):
f = Future()
def test_wait(self) -> None:
f = Future[torch.Tensor]()
f.set_result(torch.ones(2, 2))
self.assertEqual(f.wait(), torch.ones(2, 2))
def test_wait_multi_thread(self):
def test_wait_multi_thread(self) -> None:
def slow_set_future(fut, value):
time.sleep(0.5)
fut.set_result(value)
f = Future()
f = Future[torch.Tensor]()
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
t.start()
@@ -31,8 +31,8 @@ class TestFuture(TestCase):
self.assertEqual(f.wait(), torch.ones(2, 2))
t.join()
def test_mark_future_twice(self):
fut = Future()
def test_mark_future_twice(self) -> None:
fut = Future[int]()
fut.set_result(1)
with self.assertRaisesRegex(
RuntimeError,
@@ -41,14 +41,14 @@ class TestFuture(TestCase):
fut.set_result(1)
def test_pickle_future(self):
fut = Future()
fut = Future[int]()
errMsg = "Can not pickle torch.futures.Future"
with TemporaryFileName() as fname:
with self.assertRaisesRegex(RuntimeError, errMsg):
torch.save(fut, fname)
def test_then(self):
fut = Future()
fut = Future[torch.Tensor]()
then_fut = fut.then(lambda x: x.wait() + 1)
fut.set_result(torch.ones(2, 2))
@@ -56,7 +56,7 @@ class TestFuture(TestCase):
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
def test_chained_then(self):
fut = Future()
fut = Future[torch.Tensor]()
futs = []
last_fut = fut
for _ in range(20):
@@ -69,7 +69,7 @@ class TestFuture(TestCase):
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
def _test_error(self, cb, errMsg):
fut = Future()
fut = Future[int]()
then_fut = fut.then(cb)
fut.set_result(5)
@@ -99,8 +99,8 @@ class TestFuture(TestCase):
self._test_error(raise_value_error, "Expected error")
def test_collect_all(self):
fut1 = Future()
fut2 = Future()
fut1 = Future[int]()
fut2 = Future[int]()
fut_all = torch.futures.collect_all([fut1, fut2])
def slow_in_thread(fut, value):
@@ -118,8 +118,8 @@ class TestFuture(TestCase):
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
def test_wait_all(self):
fut1 = Future()
fut2 = Future()
fut1 = Future[int]()
fut2 = Future[int]()
# No error version
fut1.set_result(1)

View File

@@ -3,7 +3,7 @@
import torch
from torch import Tensor
from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
from torch._six import inf
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
@@ -118,6 +118,14 @@ class _LegacyVariableBase(object):
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
class Future(object):
def __init__(self) -> None: ...
def wait(self) -> Any: ...
def then(self, callback: Callable) -> Future: ...
def set_result(self, result: Any) -> None: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _jit_get_operation(op_name: str) -> Callable: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...

View File

@@ -1,10 +1,28 @@
from typing import Generic, TypeVar
from typing import cast, Callable, Generic, List, Type, TypeVar
import torch
from torch._six import PY37
T = TypeVar("T")
S = TypeVar("S")
class _PyFuture(torch._C.Future):
def wait(self):
if not PY37:
# Workaround for https://github.com/python/typing/issues/449 in Python 3.6
from typing import GenericMeta
class _PyFutureMeta(type(torch._C.Future), GenericMeta): # type: ignore[misc]
pass
else:
class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef]
pass
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
def wait(self) -> T:
r"""
Block until the value of this ``Future`` is ready.
@@ -15,7 +33,8 @@ class _PyFuture(torch._C.Future):
"""
return super().wait()
def then(self, callback):
# Have to use string annotations because PEP-0563 is not available in 3.6
def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S]
r"""
Append the given callback function to this ``Future``, which will be run
when the ``Future`` is completed. Multiple callbacks can be added to
@@ -52,9 +71,9 @@ class _PyFuture(torch._C.Future):
>>> # RPC return value is 5.
>>> # Chained cb done. None
"""
return super().then(callback)
return cast(Future[S], super().then(callback))
def set_result(self, result):
def set_result(self, result: T) -> None:
r"""
Set the result for this ``Future``, which will mark this ``Future`` as
completed and trigger all attached callbacks. Note that a ``Future``
@@ -85,7 +104,7 @@ class _PyFuture(torch._C.Future):
super().set_result(result)
def collect_all(futures):
def collect_all(futures: List[Future]) -> Future[List[Future]]:
r"""
Collects the provided :class:`~torch.futures.Future` objects into a single
combined :class:`~torch.futures.Future` that is completed when all of the
@@ -116,10 +135,10 @@ def collect_all(futures):
>>> # fut0 result = 0
>>> # fut1 result = 1
"""
return torch._C._collect_all(futures)
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
def wait_all(futures):
def wait_all(futures: List[Future]) -> List:
r"""
Waits for all provided futures to be complete, and returns
the list of completed values.
@@ -132,36 +151,4 @@ def wait_all(futures):
method will throw an error if ``wait`` on any
:class:`~torch.futures.Future` throws.
"""
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]
try:
# Combine the implementation class and the type class.
class Future(_PyFuture, GenericWithOneTypeVar):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
pass
except TypeError as exc:
# TypeError: metaclass conflict: the metaclass of a derived class
# must be a (non-strict) subclass of the metaclasses of all its bases
class FutureMeta(_PyFuture.__class__, GenericWithOneTypeVar.__class__):
pass
# Combine the implementation class and the type class.
class Future(_PyFuture, GenericWithOneTypeVar, metaclass=FutureMeta):
r"""
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
also exposes a set of APIs to add callback functions and set results.
"""
pass
return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]