Improve Tensor type hints (#28578)

Summary:
I've typed some attributes from ee920b92c4/torch/csrc/autograd/python_variable.cpp (L490) that were not included in the stubs so that MyPy will be aware of them. I made sure to only add those attributes that are mentioned somewhere in the documentation. If there are attributes mentioned in the documentation that are not meant to be part of the public API (or the opposite), please let me know. I've also made sure that attributes that can't be set are typed as read-only properties. If setting `dtype`, `shape`, `device` or `names` directly is not part of the public API, let me know and I'll make them properties as well.

I've also added `__len__`, `__iter__` and `__contains__`, which means MyPy will no longer complain about `len(t)`, `t1 in t2` and `for t1 in t2`.

Shameless plug: I have another typing-related PR here that needs review: https://github.com/pytorch/pytorch/pull/27445

Fixes https://github.com/pytorch/pytorch/issues/28457
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28578

Reviewed By: lerks

Differential Revision: D18113954

Pulled By: fmassa

fbshipit-source-id: 0b69a2966d22054d8d87392f19ec5aa3918773bc
This commit is contained in:
henribru 2019-10-27 04:40:04 -07:00 committed by Facebook Github Bot
parent 440b192078
commit 764e0ee882
2 changed files with 21 additions and 5 deletions

View File

@ -275,7 +275,7 @@ def generate_type_hints(fname, decls, is_tensor=False):
python_args.append('*')
render_kw_only_separator = False
python_args += ["dtype: _dtype=None",
"layout: layout=strided",
"layout: _layout=strided",
"device: Union[_device, str, None]=None",
"requires_grad:_bool=False"]

View File

@ -1,6 +1,6 @@
# ${generated_comment}
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator
from torch._six import inf
import builtins
@ -66,6 +66,7 @@ _dtype = dtype
_device = device
_qscheme = qscheme
_size = Union[Size, List[_int], Tuple[_int, ...]]
_layout = layout
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]
@ -83,16 +84,31 @@ class Generator:
# torch.tensor.Tensor doesn't get type annotations. Nobody
# should really do that, so maybe this is not so bad.
class Tensor:
dtype: _dtype = ...
shape: Size = ...
device: _device = ...
requires_grad: _bool = ...
grad: Optional[Tensor] = ...
data: Tensor = ...
names: List[str] = ...
@property
def dtype(self) -> _dtype: ...
@property
def shape(self) -> Size: ...
@property
def device(self) -> _device: ...
@property
def T(self) -> Tensor: ...
@property
def grad_fn(self) -> Optional[Any]: ...
@property
def ndim(self) -> _int: ...
@property
def layout(self) -> _layout: ...
${tensor_method_hints}
# Manually defined methods from torch/tensor.py
def __len__(self) -> _int: ...
def __iter__(self) -> Iterator[Tensor]: ...
def __contains__(self, item: Union[Tensor, Number]) -> _bool: ...
def register_hook(self, hook: Callable) -> Any: ...
def retain_grad(self) -> None: ...
def is_shared(self) -> _bool: ...