mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable typechecking for torch.futures (#41675)
Summary: Add typing declarations for torch._C.Future and torch._C._collect_all Pull Request resolved: https://github.com/pytorch/pytorch/pull/41675 Reviewed By: izdeby Differential Revision: D22627539 Pulled By: malfet fbshipit-source-id: 29b87685d65dd24ee2094bae8a84a0fe3787e7f8
This commit is contained in:
parent
750d9dea49
commit
c0bfa45f9d
4
mypy.ini
4
mypy.ini
|
|
@ -19,6 +19,7 @@ files =
|
|||
caffe2,
|
||||
aten/src/ATen/function_wrapper.py,
|
||||
test/test_complex.py,
|
||||
test/test_futures.py,
|
||||
test/test_torch.py,
|
||||
test/test_type_hints.py,
|
||||
test/test_type_info.py
|
||||
|
|
@ -53,9 +54,6 @@ ignore_errors = True
|
|||
[mypy-torch.functional.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.futures.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.testing._internal.*]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -11,19 +11,19 @@ def add_one(fut):
|
|||
|
||||
|
||||
class TestFuture(TestCase):
|
||||
def test_wait(self):
|
||||
f = Future()
|
||||
def test_wait(self) -> None:
|
||||
f = Future[torch.Tensor]()
|
||||
f.set_result(torch.ones(2, 2))
|
||||
|
||||
self.assertEqual(f.wait(), torch.ones(2, 2))
|
||||
|
||||
def test_wait_multi_thread(self):
|
||||
def test_wait_multi_thread(self) -> None:
|
||||
|
||||
def slow_set_future(fut, value):
|
||||
time.sleep(0.5)
|
||||
fut.set_result(value)
|
||||
|
||||
f = Future()
|
||||
f = Future[torch.Tensor]()
|
||||
|
||||
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
|
||||
t.start()
|
||||
|
|
@ -31,8 +31,8 @@ class TestFuture(TestCase):
|
|||
self.assertEqual(f.wait(), torch.ones(2, 2))
|
||||
t.join()
|
||||
|
||||
def test_mark_future_twice(self):
|
||||
fut = Future()
|
||||
def test_mark_future_twice(self) -> None:
|
||||
fut = Future[int]()
|
||||
fut.set_result(1)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
|
|
@ -41,14 +41,14 @@ class TestFuture(TestCase):
|
|||
fut.set_result(1)
|
||||
|
||||
def test_pickle_future(self):
|
||||
fut = Future()
|
||||
fut = Future[int]()
|
||||
errMsg = "Can not pickle torch.futures.Future"
|
||||
with TemporaryFileName() as fname:
|
||||
with self.assertRaisesRegex(RuntimeError, errMsg):
|
||||
torch.save(fut, fname)
|
||||
|
||||
def test_then(self):
|
||||
fut = Future()
|
||||
fut = Future[torch.Tensor]()
|
||||
then_fut = fut.then(lambda x: x.wait() + 1)
|
||||
|
||||
fut.set_result(torch.ones(2, 2))
|
||||
|
|
@ -56,7 +56,7 @@ class TestFuture(TestCase):
|
|||
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
|
||||
|
||||
def test_chained_then(self):
|
||||
fut = Future()
|
||||
fut = Future[torch.Tensor]()
|
||||
futs = []
|
||||
last_fut = fut
|
||||
for _ in range(20):
|
||||
|
|
@ -69,7 +69,7 @@ class TestFuture(TestCase):
|
|||
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
|
||||
|
||||
def _test_error(self, cb, errMsg):
|
||||
fut = Future()
|
||||
fut = Future[int]()
|
||||
then_fut = fut.then(cb)
|
||||
|
||||
fut.set_result(5)
|
||||
|
|
@ -99,8 +99,8 @@ class TestFuture(TestCase):
|
|||
self._test_error(raise_value_error, "Expected error")
|
||||
|
||||
def test_collect_all(self):
|
||||
fut1 = Future()
|
||||
fut2 = Future()
|
||||
fut1 = Future[int]()
|
||||
fut2 = Future[int]()
|
||||
fut_all = torch.futures.collect_all([fut1, fut2])
|
||||
|
||||
def slow_in_thread(fut, value):
|
||||
|
|
@ -118,8 +118,8 @@ class TestFuture(TestCase):
|
|||
|
||||
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
|
||||
def test_wait_all(self):
|
||||
fut1 = Future()
|
||||
fut2 = Future()
|
||||
fut1 = Future[int]()
|
||||
fut2 = Future[int]()
|
||||
|
||||
# No error version
|
||||
fut1.set_result(1)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from typing import (Any, BinaryIO, Callable, ContextManager, Iterator, List, NamedTuple,
|
||||
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
|
||||
Optional, overload, Sequence, Tuple, TypeVar, Type, Union)
|
||||
from torch._six import inf
|
||||
|
||||
from torch.types import _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage
|
||||
|
|
@ -118,6 +118,14 @@ class _LegacyVariableBase(object):
|
|||
) -> None: ...
|
||||
|
||||
# Defined in torch/csrc/jit/python/init.cpp
|
||||
class Future(object):
|
||||
def __init__(self) -> None: ...
|
||||
def wait(self) -> Any: ...
|
||||
def then(self, callback: Callable) -> Future: ...
|
||||
def set_result(self, result: Any) -> None: ...
|
||||
|
||||
|
||||
def _collect_all(futures: List[Future]) -> Future: ...
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,28 @@
|
|||
from typing import Generic, TypeVar
|
||||
from typing import cast, Callable, Generic, List, Type, TypeVar
|
||||
|
||||
import torch
|
||||
from torch._six import PY37
|
||||
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
class _PyFuture(torch._C.Future):
|
||||
def wait(self):
|
||||
if not PY37:
|
||||
# Workaround for https://github.com/python/typing/issues/449 in Python 3.6
|
||||
from typing import GenericMeta
|
||||
|
||||
class _PyFutureMeta(type(torch._C.Future), GenericMeta): # type: ignore[misc]
|
||||
pass
|
||||
else:
|
||||
class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef]
|
||||
pass
|
||||
|
||||
class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
|
||||
r"""
|
||||
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
|
||||
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
|
||||
also exposes a set of APIs to add callback functions and set results.
|
||||
"""
|
||||
def wait(self) -> T:
|
||||
r"""
|
||||
Block until the value of this ``Future`` is ready.
|
||||
|
||||
|
|
@ -15,7 +33,8 @@ class _PyFuture(torch._C.Future):
|
|||
"""
|
||||
return super().wait()
|
||||
|
||||
def then(self, callback):
|
||||
# Have to use string annotations because PEP-0563 is not available in 3.6
|
||||
def then(self, callback): # type: (Callable[[Future[T]], S]) -> Future[S]
|
||||
r"""
|
||||
Append the given callback function to this ``Future``, which will be run
|
||||
when the ``Future`` is completed. Multiple callbacks can be added to
|
||||
|
|
@ -52,9 +71,9 @@ class _PyFuture(torch._C.Future):
|
|||
>>> # RPC return value is 5.
|
||||
>>> # Chained cb done. None
|
||||
"""
|
||||
return super().then(callback)
|
||||
return cast(Future[S], super().then(callback))
|
||||
|
||||
def set_result(self, result):
|
||||
def set_result(self, result: T) -> None:
|
||||
r"""
|
||||
Set the result for this ``Future``, which will mark this ``Future`` as
|
||||
completed and trigger all attached callbacks. Note that a ``Future``
|
||||
|
|
@ -85,7 +104,7 @@ class _PyFuture(torch._C.Future):
|
|||
super().set_result(result)
|
||||
|
||||
|
||||
def collect_all(futures):
|
||||
def collect_all(futures: List[Future]) -> Future[List[Future]]:
|
||||
r"""
|
||||
Collects the provided :class:`~torch.futures.Future` objects into a single
|
||||
combined :class:`~torch.futures.Future` that is completed when all of the
|
||||
|
|
@ -116,10 +135,10 @@ def collect_all(futures):
|
|||
>>> # fut0 result = 0
|
||||
>>> # fut1 result = 1
|
||||
"""
|
||||
return torch._C._collect_all(futures)
|
||||
return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))
|
||||
|
||||
|
||||
def wait_all(futures):
|
||||
def wait_all(futures: List[Future]) -> List:
|
||||
r"""
|
||||
Waits for all provided futures to be complete, and returns
|
||||
the list of completed values.
|
||||
|
|
@ -132,36 +151,4 @@ def wait_all(futures):
|
|||
method will throw an error if ``wait`` on any
|
||||
:class:`~torch.futures.Future` throws.
|
||||
"""
|
||||
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
GenericWithOneTypeVar = Generic[T]
|
||||
|
||||
|
||||
try:
|
||||
|
||||
# Combine the implementation class and the type class.
|
||||
class Future(_PyFuture, GenericWithOneTypeVar):
|
||||
r"""
|
||||
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
|
||||
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
|
||||
also exposes a set of APIs to add callback functions and set results.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
except TypeError as exc:
|
||||
# TypeError: metaclass conflict: the metaclass of a derived class
|
||||
# must be a (non-strict) subclass of the metaclasses of all its bases
|
||||
class FutureMeta(_PyFuture.__class__, GenericWithOneTypeVar.__class__):
|
||||
pass
|
||||
|
||||
# Combine the implementation class and the type class.
|
||||
class Future(_PyFuture, GenericWithOneTypeVar, metaclass=FutureMeta):
|
||||
r"""
|
||||
Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
|
||||
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
|
||||
also exposes a set of APIs to add callback functions and set results.
|
||||
"""
|
||||
pass
|
||||
return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user