mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Improve Tensor type hints (#28578)
Summary:
I've typed some attributes from ee920b92c4/torch/csrc/autograd/python_variable.cpp (L490) that were not included in the stubs so that MyPy will be aware of them. I made sure to only add those attributes that are mentioned somewhere in the documentation. If there are attributes mentioned in the documentation that are not meant to be part of the public API (or the opposite), please let me know. I've also made sure that attributes that can't be set are typed as read-only properties. If setting `dtype`, `shape`, `device` or `names` directly is not part of the public API, let me know and I'll make them properties as well.
I've also added `__len__`, `__iter__` and `__contains__`, which means MyPy will no longer complain about `len(t)`, `t1 in t2` and `for t1 in t2`.
Shameless plug: I have another typing-related PR here that needs review: https://github.com/pytorch/pytorch/pull/27445
Fixes https://github.com/pytorch/pytorch/issues/28457
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28578
Reviewed By: lerks
Differential Revision: D18113954
Pulled By: fmassa
fbshipit-source-id: 0b69a2966d22054d8d87392f19ec5aa3918773bc
This commit is contained in:
parent
440b192078
commit
764e0ee882
|
|
@@ -275,7 +275,7 @@ def generate_type_hints(fname, decls, is_tensor=False):
                 python_args.append('*')
                 render_kw_only_separator = False
             python_args += ["dtype: _dtype=None",
-                            "layout: layout=strided",
+                            "layout: _layout=strided",
                             "device: Union[_device, str, None]=None",
                             "requires_grad:_bool=False"]
@@ -1,6 +1,6 @@
 # ${generated_comment}

-from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload
+from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator
 from torch._six import inf

 import builtins
@@ -66,6 +66,7 @@ _dtype = dtype
 _device = device
 _qscheme = qscheme
 _size = Union[Size, List[_int], Tuple[_int, ...]]
+_layout = layout

 # Meta-type for "numeric" things; matches our docs
 Number = Union[builtins.int, builtins.float, builtins.bool]
@@ -83,16 +84,31 @@ class Generator:
 # torch.tensor.Tensor doesn't get type annotations. Nobody
 # should really do that, so maybe this is not so bad.
 class Tensor:
-    dtype: _dtype = ...
-    shape: Size = ...
-    device: _device = ...
     requires_grad: _bool = ...
     grad: Optional[Tensor] = ...
     data: Tensor = ...
+    names: List[str] = ...
+    @property
+    def dtype(self) -> _dtype: ...
+    @property
+    def shape(self) -> Size: ...
+    @property
+    def device(self) -> _device: ...
+    @property
+    def T(self) -> Tensor: ...
+    @property
+    def grad_fn(self) -> Optional[Any]: ...
+    @property
+    def ndim(self) -> _int: ...
+    @property
+    def layout(self) -> _layout: ...

     ${tensor_method_hints}

     # Manually defined methods from torch/tensor.py
+    def __len__(self) -> _int: ...
+    def __iter__(self) -> Iterator[Tensor]: ...
+    def __contains__(self, item: Union[Tensor, Number]) -> _bool: ...
     def register_hook(self, hook: Callable) -> Any: ...
     def retain_grad(self) -> None: ...
     def is_shared(self) -> _bool: ...
Loading…
Reference in New Issue
Block a user