Enable torch.tensor typechecks (#45077)

Summary:
This fixes https://github.com/pytorch/pytorch/issues/42983 by removing the mypy ignore_errors override for torch.tensor and fixing the resulting type errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45077

Reviewed By: ezyang

Differential Revision: D23842493

Pulled By: walterddr

fbshipit-source-id: 1c516a5ff351743a187d00cba7ed0be11678edf1
Author: Rong Rong, 2020-09-24 08:20:06 -07:00 (committed by Facebook GitHub Bot)
Commit: bea7901e38 (parent: dc67b47bc9)
5 changed files with 65 additions and 37 deletions


@@ -102,9 +102,6 @@ ignore_errors = True
[mypy-torch.distributions.*]
ignore_errors = True
[mypy-torch.tensor]
ignore_errors = True
[mypy-torch._tensor_str]
ignore_errors = True
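
(The hunk above appears to be from mypy.ini.) Removing the [mypy-torch.tensor] override means torch/tensor.py is no longer exempt from type checking, so the rest of this commit makes that module pass cleanly. A hedged sketch of how one might verify the result locally with mypy's programmatic API, assuming the working directory is the pytorch checkout:

    # Hypothetical local check; mypy.api.run returns (report, errors, exit_status).
    from mypy import api

    report, errors, status = api.run(["--config-file", "mypy.ini", "torch/tensor.py"])
    print(report)
    print("clean" if status == 0 else "type errors found (exit %d)" % status)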


@@ -74,11 +74,7 @@ blocklist = [
# Somehow, these are defined in both _C and in functional. Ick!
'broadcast_tensors',
# Manually define named tensor type stubs in __init__.pyi.in
'rename',
'refine_names',
'align_to',
'align_tensors',
'unflatten',
'meshgrid',
'cartesian_prod',
'block_diag',
@@ -87,7 +83,6 @@ blocklist = [
'stft',
'istft',
'tensordot',
'norm',
'split',
'unique_consecutive',
'atleast_1d',
@@ -536,6 +531,7 @@ def gen_pyi(declarations_path, out):
'def __init__(self, other: Tensor) -> None: ...',
'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
],
'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
@@ -546,6 +542,7 @@ def gen_pyi(declarations_path, out):
'tolist': ['def tolist(self) -> List: ...'],
'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
'element_size': ['def element_size(self) -> _int: ...'],
'data_ptr': ['def data_ptr(self) -> _int: ...'],
'dim': ['def dim(self) -> _int: ...'],
'nonzero': ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
'numel': ['def numel(self) -> _int: ...'],
@@ -576,6 +573,10 @@ def gen_pyi(declarations_path, out):
],
'item': ["def item(self) -> Number: ..."],
'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
'set_': ['def set_(self, storage: Storage, offset: _int, size: _size, stride: _size) -> Tensor: ...',
'def set_(self, storage: Storage) -> Tensor: ...'],
'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
})
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
for inplace in [False, True]:
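
(The hunks above appear to be from tools/pyi/gen_pyi.py.) Shrinking the blocklist lets those functions receive auto-generated signatures, while the new hand-written entries (as_subclass, data_ptr, set_, split) are merged into the generated _TensorBase stub. A hedged sketch of ordinary torch calls that those stub entries describe and that mypy should now accept; the shapes are arbitrary:

    # Hedged sketch: standard torch APIs covered by the new stub entries.
    import torch

    t = torch.zeros(4, 6)
    ptr: int = t.data_ptr()        # data_ptr() is declared to return an int
    rows = t.split(2, dim=0)       # overload taking a single int split size
    cols = t.split((2, 4), dim=1)  # overload taking a tuple of split sizes
    print(ptr, len(rows), len(cols))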


@@ -87,6 +87,9 @@ ${dtype_class_hints}
class layout:
...
# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...
# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
sparse_coo : layout = ...
@@ -105,6 +108,10 @@ class qscheme: ...
# Defined in torch/csrc/utils/tensor_qschemes.cpp
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...
# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
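
(The hunk above appears to be from torch/_C/__init__.pyi.in.) It declares DisableTorchFunction and the remaining qscheme constants so that code referring to them typechecks. A small hedged illustration, using only standard runtime attributes:

    # Hedged illustration: the constants newly declared in the stub are ordinary
    # runtime objects, so annotating them as torch.qscheme should pass mypy.
    import torch

    for s in (torch.per_tensor_symmetric,
              torch.per_channel_symmetric,
              torch.per_channel_affine_float_qparams):
        scheme: torch.qscheme = s
        assert isinstance(scheme, torch.qscheme)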


@@ -7,6 +7,7 @@ import torch.utils.hooks as hooks
import warnings
import weakref
from torch._C import _add_docstr
from typing import Any, Dict, Tuple, Union
from numbers import Number
import functools
from typing import Optional
@@ -53,6 +54,8 @@ class Tensor(torch._C._TensorBase):
else:
new_storage = self.storage().__deepcopy__(memo)
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
@@ -85,6 +88,7 @@ class Tensor(torch._C._TensorBase):
check_serializing_named_tensor(self)
# See Note [Don't serialize hooks]
torch.utils.hooks.warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()
# Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
# We considered a few options:
# 1. CPU tensor can't be used here.
@@ -96,12 +100,14 @@ class Tensor(torch._C._TensorBase):
# `tolist()` converts every single element in the tensor into python objects
# and serialize them one by one.
if self.device.type == 'xla':
args = (self.cpu().numpy(),
self.dtype,
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, args)
arg_xla = (self.cpu().numpy(),
self.dtype,
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, arg_xla)
if self.is_quantized:
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = (torch.per_tensor_affine,
self.q_scale(),
@@ -116,31 +122,31 @@ class Tensor(torch._C._TensorBase):
self.q_per_channel_axis())
else:
raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
args = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
quantizer_params,
self.requires_grad,
OrderedDict())
return (torch._utils._rebuild_qtensor, args)
args_qtensor = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
quantizer_params,
self.requires_grad,
backward_hooks)
return (torch._utils._rebuild_qtensor, args_qtensor)
elif self.is_sparse:
if self.layout == torch.sparse_coo:
args = (self.layout,
(self._indices(),
self._values(),
self.size()))
args_sparse = (self.layout,
(self._indices(),
self._values(),
self.size()))
else:
raise NotImplementedError(
'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout))
return (torch._utils._rebuild_sparse_tensor, args)
return (torch._utils._rebuild_sparse_tensor, args_sparse)
else:
args = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
self.requires_grad,
OrderedDict()) # previously was self._backward_hooks
backward_hooks) # previously was self._backward_hooks
return (torch._utils._rebuild_tensor_v2, args)
def __setstate__(self, state):
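
The __reduce_ex__ changes above show the main pattern in torch/tensor.py: mypy fixes a variable's type at its first assignment, so reusing one args tuple for the XLA, quantized, sparse, and dense branches no longer typechecks. The commit therefore gives each branch its own name (arg_xla, args_qtensor, args_sparse) or, for quantizer_params, declares an explicit Union up front. A minimal standalone sketch of the same issue (hypothetical functions, nothing torch-specific):

    # Re-binding one variable to differently-shaped tuples fails under mypy;
    # an explicit Union annotation (as used for quantizer_params) fixes it.
    from typing import Tuple, Union

    def rebound(flag: bool) -> tuple:
        args = (1, 2.0)        # inferred as Tuple[int, float]
        if flag:
            args = (1, 2, 3)   # mypy: incompatible types in assignment
        return args

    def annotated(flag: bool) -> tuple:
        args: Union[Tuple[int, float], Tuple[int, int, int]]
        if flag:
            args = (1, 2, 3)   # OK: covered by the declared Union
        else:
            args = (1, 2.0)
        return args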
@@ -528,7 +534,7 @@ class Tensor(torch._C._TensorBase):
return self.item().__format__(format_spec)
return object.__format__(self, format_spec)
def __ipow__(self, other):
def __ipow__(self, other): # type: ignore[misc]
relevant_args = (self, other)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
@@ -652,7 +658,8 @@ class Tensor(torch._C._TensorBase):
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
if isinstance(element, (torch.Tensor, Number)):
return (element == self).any().item()
# type hint doesn't understand the __contains__ result array
return (element == self).any().item() # type: ignore[union-attr]
raise RuntimeError(
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
@@ -669,7 +676,8 @@ class Tensor(torch._C._TensorBase):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # type: ignore[attr-defined]
# raise AttributeError for unsupported tensors, so that
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -936,7 +944,8 @@ class Tensor(torch._C._TensorBase):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__get__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__get__, relevant_args, self) # type: ignore[attr-defined]
if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -951,7 +960,8 @@ class Tensor(torch._C._TensorBase):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) # type: ignore[attr-defined]
self._grad = new_grad
@grad.deleter
@@ -959,7 +969,8 @@ class Tensor(torch._C._TensorBase):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) # type: ignore[attr-defined]
del self._grad
@classmethod
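
The repeated # type: ignore[attr-defined] comments above work around mypy's lack of support for accessing a property descriptor through the class (python/mypy#6185): Tensor.grad is typed as the property's value, so .__get__/.__set__/.__delete__ look undefined even though they exist at runtime. A tiny hedged reproduction, independent of torch:

    # This runs fine, but mypy types C.prop as the getter's return type (int),
    # so the .__get__ access is reported as [attr-defined].
    class C:
        @property
        def prop(self) -> int:
            return 41

    getter = C.prop.__get__   # runtime: the property descriptor's bound __get__
    print(getter(C()) + 1)    # prints 42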


@@ -34,13 +34,25 @@ Device = Union[_device, str, None]
class Storage(object):
_cdata: int
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
...
def size(self) -> int:
def __deepcopy__(self, memo) -> 'Storage':
...
def _new_shared(self, int) -> 'Storage':
...
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
...
def element_size(self) -> int:
...
def is_shared(self) -> bool:
...
def share_memory_(self) -> 'Storage':
...
def size(self) -> int:
...
...
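
(The last hunk appears to be from torch/types.py.) It extends the Storage stub with __deepcopy__, _new_shared, element_size, is_shared, and share_memory_, which torch/tensor.py's __deepcopy__ and __reduce_ex__ rely on. A hedged usage sketch of the newly declared methods, using standard storage APIs and arbitrary data:

    # Hedged sketch: these storage methods exist at runtime; the stub additions
    # let calls like this typecheck.
    import copy
    import torch

    s = torch.arange(6, dtype=torch.float32).storage()
    print(s.size(), s.element_size(), s.is_shared())
    s2 = copy.deepcopy(s)          # goes through Storage.__deepcopy__
    print(s2.size() == s.size())   # True: the copy has the same length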