pytorch/torch/_C/_VariableFunctions.pyi.in
Philip Meier 9d6109c4b0 improve annotations (#86105)
In `torchvision` we started to use tensor subclasses. With the current annotations, this minimal example produces three errors when checked with `mypy`:

```py
from typing import Type, TypeVar, Any, Optional, Union

import torch

T = TypeVar("T", bound="TensorSubclass")

class TensorSubclass(torch.Tensor):
    def __new__(
        cls: Type[T],
        data: Any,
        *,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[torch.device, str, int]] = None,
    ) -> T:
        return torch.as_tensor(data, dtype=dtype, device=device).as_subclass(cls)
```

```
main.py:16:16: error: Incompatible return value type (got "Tensor", expected "T")  [return-value]
main.py:16:58: error: Argument "device" to "as_tensor" has incompatible type "Union[device, str, int, None]"; expected "Optional[device]"  [arg-type]
main.py:16:78: error: Argument 1 to "as_subclass" of "_TensorBase" has incompatible type "Type[T]"; expected "Tensor"  [arg-type]
```

I'll explain inline why the old annotations are wrong.
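At a high level, two kinds of stub changes make the example above type-check: `as_subclass` has to accept `Type[S]` and return `S` for a tensor-bound TypeVar, and the `device` parameter of `as_tensor` has to accept everything the runtime accepts (device objects, strings, and integer indices). The snippet below is an illustrative sketch of that shape, not the literal diff of this PR; the TypeVar name `S` and the standalone `_TensorBase` definition are only for the example.

```py
from typing import Any, Optional, Type, TypeVar, Union

import torch

S = TypeVar("S", bound="torch.Tensor")

class _TensorBase:
    # Taking Type[S] and returning S lets mypy thread the subclass type through
    # as_subclass, which resolves the return-value error (1) and the
    # "expected Tensor, got Type[T]" arg-type error (3) above.
    def as_subclass(self, cls: Type[S]) -> S: ...

def as_tensor(
    data: Any,
    dtype: Optional[torch.dtype] = None,
    # Widening `device` beyond Optional[torch.device] matches the runtime
    # behaviour (device objects, strings such as "cuda:0", integer indices)
    # and resolves the device arg-type error (2) above.
    device: Optional[Union[torch.device, str, int]] = None,
) -> torch.Tensor: ...
```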

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86105
Approved by: https://github.com/albanD
2022-10-05 10:33:26 +00:00

# ${generated_comment}
from torch import Tensor, Generator, strided, memory_format, contiguous_format
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from typing_extensions import Literal
from torch._six import inf
from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
import torch
import builtins
${function_hints}
${all_directive}