mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable torch.tensor typechecks (#45077)
Summary: this fixes https://github.com/pytorch/pytorch/issues/42983. Pull Request resolved: https://github.com/pytorch/pytorch/pull/45077 Reviewed By: ezyang Differential Revision: D23842493 Pulled By: walterddr fbshipit-source-id: 1c516a5ff351743a187d00cba7ed0be11678edf1
This commit is contained in:
parent
dc67b47bc9
commit
bea7901e38
3
mypy.ini
3
mypy.ini
|
|
@ -102,9 +102,6 @@ ignore_errors = True
|
|||
[mypy-torch.distributions.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.tensor]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch._tensor_str]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
|
|
@ -74,11 +74,7 @@ blocklist = [
|
|||
# Somehow, these are defined in both _C and in functional. Ick!
|
||||
'broadcast_tensors',
|
||||
# Manually define named tensor type stubs in __init__.pyi.in
|
||||
'rename',
|
||||
'refine_names',
|
||||
'align_to',
|
||||
'align_tensors',
|
||||
'unflatten',
|
||||
'meshgrid',
|
||||
'cartesian_prod',
|
||||
'block_diag',
|
||||
|
|
@ -87,7 +83,6 @@ blocklist = [
|
|||
'stft',
|
||||
'istft',
|
||||
'tensordot',
|
||||
'norm',
|
||||
'split',
|
||||
'unique_consecutive',
|
||||
'atleast_1d',
|
||||
|
|
@ -536,6 +531,7 @@ def gen_pyi(declarations_path, out):
|
|||
'def __init__(self, other: Tensor) -> None: ...',
|
||||
'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
|
||||
],
|
||||
'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
|
||||
# clamp has no default values in the Declarations
|
||||
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
|
||||
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
|
||||
|
|
@ -546,6 +542,7 @@ def gen_pyi(declarations_path, out):
|
|||
'tolist': ['def tolist(self) -> List: ...'],
|
||||
'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
|
||||
'element_size': ['def element_size(self) -> _int: ...'],
|
||||
'data_ptr': ['def data_ptr(self) -> _int: ...'],
|
||||
'dim': ['def dim(self) -> _int: ...'],
|
||||
'nonzero': ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
|
||||
'numel': ['def numel(self) -> _int: ...'],
|
||||
|
|
@ -576,6 +573,10 @@ def gen_pyi(declarations_path, out):
|
|||
],
|
||||
'item': ["def item(self) -> Number: ..."],
|
||||
'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
|
||||
'set_': ['def set_(self, storage: Storage, offset: _int, size: _size, stride: _size) -> Tensor: ...',
|
||||
'def set_(self, storage: Storage) -> Tensor: ...'],
|
||||
'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
|
||||
'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
|
||||
})
|
||||
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
|
||||
for inplace in [False, True]:
|
||||
|
|
|
|||
|
|
@ -87,6 +87,9 @@ ${dtype_class_hints}
|
|||
class layout:
|
||||
...
|
||||
|
||||
# Defined in torch/csrc/utils/disable_torch_function.cpp
|
||||
def DisableTorchFunction(): ...
|
||||
|
||||
# Defined in torch/csrc/utils/tensor_layouts.cpp
|
||||
strided : layout = ...
|
||||
sparse_coo : layout = ...
|
||||
|
|
@ -105,6 +108,10 @@ class qscheme: ...
|
|||
|
||||
# Defined in torch/csrc/utils/tensor_qschemes.cpp
|
||||
per_tensor_affine: qscheme = ...
|
||||
per_channel_affine: qscheme = ...
|
||||
per_tensor_symmetric: qscheme = ...
|
||||
per_channel_symmetric: qscheme = ...
|
||||
per_channel_affine_float_qparams: qscheme = ...
|
||||
|
||||
# Defined in torch/csrc/autograd/python_function.cpp
|
||||
class _FunctionBase(object):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import torch.utils.hooks as hooks
|
|||
import warnings
|
||||
import weakref
|
||||
from torch._C import _add_docstr
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from numbers import Number
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
|
@ -53,6 +54,8 @@ class Tensor(torch._C._TensorBase):
|
|||
else:
|
||||
new_storage = self.storage().__deepcopy__(memo)
|
||||
if self.is_quantized:
|
||||
# quantizer_params can be different type based on torch attribute
|
||||
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]]
|
||||
if self.qscheme() == torch.per_tensor_affine:
|
||||
quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
|
||||
elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
|
||||
|
|
@ -85,6 +88,7 @@ class Tensor(torch._C._TensorBase):
|
|||
check_serializing_named_tensor(self)
|
||||
# See Note [Don't serialize hooks]
|
||||
torch.utils.hooks.warn_if_has_hooks(self)
|
||||
backward_hooks: Dict[Any, Any] = OrderedDict()
|
||||
# Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
|
||||
# We considered a few options:
|
||||
# 1. CPU tensor can't be used here.
|
||||
|
|
@ -96,12 +100,14 @@ class Tensor(torch._C._TensorBase):
|
|||
# `tolist()` converts every single element in the tensor into python objects
|
||||
# and serialize them one by one.
|
||||
if self.device.type == 'xla':
|
||||
args = (self.cpu().numpy(),
|
||||
self.dtype,
|
||||
str(self.device),
|
||||
self.requires_grad)
|
||||
return (torch._utils._rebuild_xla_tensor, args)
|
||||
arg_xla = (self.cpu().numpy(),
|
||||
self.dtype,
|
||||
str(self.device),
|
||||
self.requires_grad)
|
||||
return (torch._utils._rebuild_xla_tensor, arg_xla)
|
||||
if self.is_quantized:
|
||||
# quantizer_params can be different type based on torch attribute
|
||||
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
|
||||
if self.qscheme() == torch.per_tensor_affine:
|
||||
quantizer_params = (torch.per_tensor_affine,
|
||||
self.q_scale(),
|
||||
|
|
@ -116,31 +122,31 @@ class Tensor(torch._C._TensorBase):
|
|||
self.q_per_channel_axis())
|
||||
else:
|
||||
raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
|
||||
args = (self.storage(),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
self.stride(),
|
||||
quantizer_params,
|
||||
self.requires_grad,
|
||||
OrderedDict())
|
||||
return (torch._utils._rebuild_qtensor, args)
|
||||
args_qtensor = (self.storage(),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
self.stride(),
|
||||
quantizer_params,
|
||||
self.requires_grad,
|
||||
backward_hooks)
|
||||
return (torch._utils._rebuild_qtensor, args_qtensor)
|
||||
elif self.is_sparse:
|
||||
if self.layout == torch.sparse_coo:
|
||||
args = (self.layout,
|
||||
(self._indices(),
|
||||
self._values(),
|
||||
self.size()))
|
||||
args_sparse = (self.layout,
|
||||
(self._indices(),
|
||||
self._values(),
|
||||
self.size()))
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout))
|
||||
return (torch._utils._rebuild_sparse_tensor, args)
|
||||
return (torch._utils._rebuild_sparse_tensor, args_sparse)
|
||||
else:
|
||||
args = (self.storage(),
|
||||
self.storage_offset(),
|
||||
tuple(self.size()),
|
||||
self.stride(),
|
||||
self.requires_grad,
|
||||
OrderedDict()) # previously was self._backward_hooks
|
||||
backward_hooks) # previously was self._backward_hooks
|
||||
return (torch._utils._rebuild_tensor_v2, args)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
@ -528,7 +534,7 @@ class Tensor(torch._C._TensorBase):
|
|||
return self.item().__format__(format_spec)
|
||||
return object.__format__(self, format_spec)
|
||||
|
||||
def __ipow__(self, other):
|
||||
def __ipow__(self, other): # type: ignore[misc]
|
||||
relevant_args = (self, other)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
|
||||
|
|
@ -652,7 +658,8 @@ class Tensor(torch._C._TensorBase):
|
|||
if type(self) is not Tensor and has_torch_function(relevant_args):
|
||||
return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
|
||||
if isinstance(element, (torch.Tensor, Number)):
|
||||
return (element == self).any().item()
|
||||
# type hint doesn't understand the __contains__ result array
|
||||
return (element == self).any().item() # type: ignore[union-attr]
|
||||
|
||||
raise RuntimeError(
|
||||
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
|
||||
|
|
@ -669,7 +676,8 @@ class Tensor(torch._C._TensorBase):
|
|||
relevant_args = (self,)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
if type(self) is not Tensor and has_torch_function(relevant_args):
|
||||
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)
|
||||
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
|
||||
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # type: ignore[attr-defined]
|
||||
|
||||
# raise AttributeError for unsupported tensors, so that
|
||||
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
|
||||
|
|
@ -936,7 +944,8 @@ class Tensor(torch._C._TensorBase):
|
|||
relevant_args = (self,)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
if type(self) is not Tensor and has_torch_function(relevant_args):
|
||||
return handle_torch_function(Tensor.grad.__get__, relevant_args, self)
|
||||
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
|
||||
return handle_torch_function(Tensor.grad.__get__, relevant_args, self) # type: ignore[attr-defined]
|
||||
|
||||
if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
|
||||
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
|
||||
|
|
@ -951,7 +960,8 @@ class Tensor(torch._C._TensorBase):
|
|||
relevant_args = (self,)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
if type(self) is not Tensor and has_torch_function(relevant_args):
|
||||
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)
|
||||
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
|
||||
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) # type: ignore[attr-defined]
|
||||
self._grad = new_grad
|
||||
|
||||
@grad.deleter
|
||||
|
|
@ -959,7 +969,8 @@ class Tensor(torch._C._TensorBase):
|
|||
relevant_args = (self,)
|
||||
from torch.overrides import has_torch_function, handle_torch_function
|
||||
if type(self) is not Tensor and has_torch_function(relevant_args):
|
||||
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)
|
||||
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
|
||||
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) # type: ignore[attr-defined]
|
||||
del self._grad
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -34,13 +34,25 @@ Device = Union[_device, str, None]
|
|||
class Storage(object):
|
||||
_cdata: int
|
||||
|
||||
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
|
||||
...
|
||||
|
||||
def size(self) -> int:
|
||||
def __deepcopy__(self, memo) -> 'Storage':
|
||||
...
|
||||
|
||||
def _new_shared(self, int) -> 'Storage':
|
||||
...
|
||||
|
||||
def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
|
||||
...
|
||||
|
||||
def element_size(self) -> int:
|
||||
...
|
||||
|
||||
def is_shared(self) -> bool:
|
||||
...
|
||||
|
||||
def share_memory_(self) -> 'Storage':
|
||||
...
|
||||
|
||||
def size(self) -> int:
|
||||
...
|
||||
|
||||
...
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user