Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
In `torchvision` we have started to use tensor subclasses. With the current annotations, `mypy` reports three errors for this minimal example:
```py
from typing import Type, TypeVar, Any, Optional, Union

import torch

T = TypeVar("T", bound="TensorSubclass")


class TensorSubclass(torch.Tensor):
    def __new__(
        cls: Type[T],
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
    ) -> T:
        return torch.as_tensor(data, dtype=dtype, device=device).as_subclass(cls)
```
```
main.py:16:16: error: Incompatible return value type (got "Tensor", expected "T") [return-value]
main.py:16:58: error: Argument "device" to "as_tensor" has incompatible type "Union[device, str, int, None]"; expected "Optional[device]" [arg-type]
main.py:16:78: error: Argument 1 to "as_subclass" of "_TensorBase" has incompatible type "Type[T]"; expected "Tensor" [arg-type]
```
I'll explain inline why the old annotations are wrong.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86105
Approved by: https://github.com/albanD
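A minimal, torch-free sketch of the annotation shape the three errors above call for. It assumes hypothetical stand-ins `FakeTensor`, `FakeDevice` and `Sub` in place of `torch.Tensor`, `torch.device` and `TensorSubclass`; it is not PyTorch's actual stub. `as_subclass` takes the class object itself (`Type[S]`) and returns an instance of that class (`S`), and `device` accepts the same `Union[device, str, int, None]` shape the caller uses.

```py
from typing import Optional, Type, TypeVar, Union


class FakeDevice:
    """Hypothetical stand-in for torch.device."""

    def __init__(self, kind: str) -> None:
        self.kind = kind


# Same shape as the caller's annotation: a device object, a string such as
# "cuda:0", an integer index, or None.
Device = Union[FakeDevice, str, int, None]

S = TypeVar("S", bound="FakeTensor")


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (no real tensor math)."""

    def __init__(self, data: object, device: Device = None) -> None:
        self.data = data
        self.device = device

    # Accept a class (not an instance) and return an instance of that class,
    # so callers keep their subclass type instead of getting the base type.
    def as_subclass(self, cls: Type[S]) -> S:
        obj = object.__new__(cls)  # allocate without running cls.__new__/__init__
        obj.__dict__.update(vars(self))  # share the same payload objects
        return obj


T = TypeVar("T", bound="Sub")


class Sub(FakeTensor):
    """Hypothetical counterpart of TensorSubclass from the example above."""

    def __new__(cls: Type[T], data: object, *, device: Device = None) -> T:
        # Returns T rather than the base type, mirroring the fixed annotations.
        return FakeTensor(data, device=device).as_subclass(cls)


t = Sub([1, 2, 3], device="cpu")
print(type(t).__name__, t.data, t.device)
```

The `TypeVar` bound to the base class is what lets a single `as_subclass` signature serve every subclass without casts.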
# ${generated_comment}

from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from typing_extensions import Literal
from torch._six import inf

from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device

import torch

import builtins

${function_hints}

${all_directive}
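The template above is not valid Python on its own: the `${generated_comment}`, `${function_hints}` and `${all_directive}` placeholders are filled in by a code generator. The sketch below shows that substitution step under stated assumptions. It uses Python's `string.Template` only as a stand-in, because its `${name}` syntax happens to match; the real generator script is not shown here. `render_stub` and the hint string are hypothetical and illustrative, not the exact generated output.

```py
from string import Template
from typing import List

# A trimmed copy of the template above: only the placeholders and one import
# are kept so the sketch stays short.
PYI_TEMPLATE = """\
# ${generated_comment}

from torch.types import Device

${function_hints}

${all_directive}
"""


def render_stub(hints: List[str], names: List[str]) -> str:
    """Fill every placeholder via string.Template (stand-in for the real generator)."""
    return Template(PYI_TEMPLATE).substitute(
        generated_comment="@generated by a hypothetical stub generator",
        function_hints="\n".join(hints),
        all_directive="__all__ = " + repr(sorted(names)),
    )


# Illustrative hint only; the exact generated signature may differ.
example = render_stub(
    ["def as_tensor(data: Any, dtype: Optional[_dtype] = None, device: Device = None) -> Tensor: ..."],
    ["as_tensor"],
)
print(example)
```

`substitute` raises `KeyError` if any placeholder is left unfilled, which is a reasonable default for a generator that should never emit a half-rendered stub.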