"""
This module implements generation of type stubs for PyTorch,
enabling use of autocomplete in IDEs like PyCharm, which otherwise
don't understand C extension modules.

At the moment, this module only handles type stubs for torch and
torch.Tensor. It should eventually be expanded to cover all functions
which are autogenerated.

Here's our general strategy:

- We start off with a hand-written __init__.pyi.in file. This
  file contains type definitions for everything we cannot automatically
  generate, including pure Python definitions directly in __init__.py
  (the latter case should be pretty rare).

- We go through automatically bound functions based on the
  type information recorded in native_functions.yaml and
  generate type hints for them (generate_type_hints)

There are a number of type hints which we've special-cased;
read gen_pyi for the gory details.
"""

import argparse
import collections
from pprint import pformat
from typing import Callable, Dict, List, Optional, Sequence

from torchgen.api.python import (
    PythonSignatureGroup,
    PythonSignatureNativeFunctionPair,
    returns_named_tuple_pyi,
)
from torchgen.gen import parse_native_yaml, parse_tags_yaml
from torchgen.model import DispatchKey, Variant
from torchgen.utils import FileManager

from tools.autograd.gen_python_functions import (
    group_overloads,
    load_signatures,
    should_generate_py_binding,
)


def get_py_torch_functions(
    python_funcs: Sequence[PythonSignatureNativeFunctionPair],
    method: bool = False,
) -> Sequence[PythonSignatureGroup]:
    """
    Get declarations (grouped by name) which should be generated
    as either functions in the "torch" module or methods on Tensor.

    Args:
        python_funcs: signature/native-function pairs to filter.
        method: if True, keep entries whose native function has the
            ``Variant.method`` variant; otherwise keep ``Variant.function``.

    Returns:
        The filtered entries, grouped by ``group_overloads``.
    """
    # The function and method predicates only differ in the variant checked.
    variant = Variant.method if method else Variant.function

    def should_bind(python_func: PythonSignatureNativeFunctionPair) -> bool:
        return (
            should_generate_py_binding(python_func.function)
            and not python_func.function.python_module
            and variant in python_func.function.variants
        )

    return group_overloads([f for f in python_funcs if should_bind(f)])


# TODO: Consider defining some aliases for our Union[...] types, to make
# the stubs easier for humans to read.

DEVICE_PARAM = "device: Device = None"
FACTORY_PARAMS = (
    f"dtype: Optional[_dtype] = None, {DEVICE_PARAM}, requires_grad: _bool = False"
)

# this could be more precise w.r.t list contents etc. How to do Ellipsis?
INDICES = "indices: Union[None, _int, slice, Tensor, List, Tuple]"

blocklist = [
    "__init_subclass__",
    "__new__",
    "__subclasshook__",
    "cdist",
    "device",
    "grad",
    "requires_grad",
    "range",
    # defined in functional
    "einsum",
    # Somehow, these are defined in both _C and in functional. Ick!
    "broadcast_tensors",
    # Manually define named tensor type stubs in __init__.pyi.in
    "align_tensors",
    "meshgrid",
    "cartesian_prod",
    "block_diag",
    "norm",
    "chain_matmul",
    "stft",
    "tensordot",
    "split",
    "unique_consecutive",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    # These are handled specially by python_arg_parser.cpp
    "add",
    "add_",
    "add_out",
    "sub",
    "sub_",
    "sub_out",
    "mul",
    "mul_",
    "mul_out",
    "div",
    "div_",
    "div_out",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "floor_divide",
    "floor_divide_",
    "floor_divide_out",
    "to",
    "_to_copy",
    "copy_",
]

binary_ops = (
    "add",
    "sub",
    "mul",
    "div",
    "pow",
    "lshift",
    "rshift",
    "mod",
    "truediv",
    "matmul",
    "floordiv",
    "radd",
    "rsub",
    "rmul",
    "rtruediv",
    "rfloordiv",
    "rpow",  # reverse arithmetic
    "and",
    "or",
    "xor",
    "rand",
    "ror",
    "rxor",  # logic
    "iadd",
    "iand",
    "idiv",
    "ilshift",
    "imul",
    "ior",
    "irshift",
    "isub",
    "ixor",
    "ifloordiv",
    "imod",  # inplace ops
)
symmetric_comparison_ops = ("eq", "ne")
asymmetric_comparison_ops = ("ge", "gt", "lt", "le")
comparison_ops = symmetric_comparison_ops + asymmetric_comparison_ops

unary_ops = ("neg", "abs", "invert")
to_py_type_ops = ("bool", "float", "complex", "long", "index", "int", "nonzero")
all_ops = binary_ops + comparison_ops + unary_ops + to_py_type_ops


def sig_for_ops(opname: str) -> List[str]:
    """sig_for_ops(opname : str) -> List[str]

    Returns signatures for operator special functions (__add__ etc.)

    Raises:
        AssertionError: if ``opname`` is not wrapped in double underscores.
        Exception: if the wrapped name is not one of the known operator groups.
    """

    # we have to do this by hand, because they are hand-bound in Python

    assert opname.endswith("__") and opname.startswith("__"), f"Unexpected op {opname}"

    name = opname[2:-2]
    if name in binary_ops:
        return [f"def {opname}(self, other: Any) -> Tensor: ..."]
    elif name in comparison_ops:
        sig = f"def {opname}(self, other: Any) -> Tensor: ..."
        if name in symmetric_comparison_ops:
            # unsafe override https://github.com/python/mypy/issues/5704
            sig += " # type: ignore[override]"
        return [sig]
    elif name in unary_ops:
        return [f"def {opname}(self) -> Tensor: ..."]
    elif name in to_py_type_ops:
        if name in {"bool", "float", "complex"}:
            tname = name
        elif name == "nonzero":
            tname = "bool"
        else:
            tname = "int"
        # Every result is a Python builtin; qualify it so the stub's own
        # aliases (e.g. `float` shadowed as a dtype) cannot capture it.
        return [f"def {opname}(self) -> builtins.{tname}: ..."]
    else:
        raise Exception("unknown op", opname)


def generate_type_hints(sig_group: PythonSignatureGroup) -> List[str]:
    """Return the stub lines for one group of overloaded signatures.

    Blocklisted, non-deprecated names produce no hints. Deprecated groups
    with an out variant produce an extra functional-only signature. Every
    remaining group yields its main signature plus an optional vararg form.
    """
    type_hints: List[str] = []

    # Some deprecated ops that are on the blocklist are still included in pyi
    if sig_group.signature.name in blocklist and not sig_group.signature.deprecated:
        return type_hints

    # deprecated signatures have separate entries for their functional and out variants
    # (as opposed to the native ops, which fuse the two into a single signature).
    # generate the functional variant here, if an out variant exists.
    if sig_group.signature.deprecated and sig_group.outplace is not None:
        type_hint = sig_group.signature.signature_str_pyi(skip_outputs=True)
        type_hints.append(type_hint)

    # PythonSignatureGroups that have both a functional + out variant get a single
    # signature, with an optional out argument.
    # Generates the out variant if one exists. Otherwise, generate the functional variant
    type_hint = sig_group.signature.signature_str_pyi(
        skip_outputs=sig_group.outplace is None
    )
    type_hints.append(type_hint)

    # Some operators also additionally have a vararg variant of their signature
    type_hint_vararg = sig_group.signature.signature_str_pyi_vararg(
        skip_outputs=sig_group.outplace is None
    )
    if type_hint_vararg:
        type_hints.append(type_hint_vararg)

    return type_hints


def get_max_pool_dispatch(name: str, arg_list: List[str]) -> Dict[str, List[str]]:
    """Build the three overloads of a `boolean_dispatch`-ed pooling function.

    ``arg_list`` must contain the literal placeholder ``"{return_indices}"``.
    The result maps ``name`` to three signatures:
    - ``return_indices=False`` (default), returning ``Tensor``;
    - ``return_indices=True`` passed positionally, returning a tuple;
    - ``return_indices=True`` passed by keyword, returning a tuple.

    Raises:
        ValueError: if the placeholder is missing (from ``list.index``).
    """
    flag_pos = arg_list.index("{return_indices}")
    # If return_indices is positional arg, everything before should have no default
    arg_list_positional = (
        [
            ", ".join(single_arg.split(" = ")[0] for single_arg in arg.split(", "))
            for arg in arg_list[: flag_pos + 1]
        ]
        + ["/"]
        + arg_list[flag_pos + 1 :]
    )
    # Otherwise force return_indices to be kwarg
    arg_list_keyword = arg_list.copy()
    arg_list_keyword.insert(flag_pos, "*")
    # Two-stage formatting: the first .format fills name/args (leaving the
    # "{return_indices}" placeholder from the args intact), the second fills
    # return_indices and return_type.
    tmpl = "def {name}({args}) -> {{return_type}}: ..."
    return {
        name: [
            tmpl.format(name=name, args=", ".join(arg_list)).format(
                return_indices="return_indices: Literal[False] = False",
                return_type="Tensor",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_positional)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
            tmpl.format(name=name, args=", ".join(arg_list_keyword)).format(
                return_indices="return_indices: Literal[True]",
                return_type="Tuple[Tensor, Tensor]",
            ),
        ]
    }


def _sorted_hints_with_overloads(
    unsorted_hints: Dict[str, List[str]],
    transform: Optional[Callable[[str], str]] = None,
) -> List[str]:
    """Flatten a name -> hints mapping into a list of stub entries.

    Names are visited in sorted order. ``transform`` (if given) is applied to
    each hint first. Names with more than one hint get every hint prefixed
    with ``@overload``.
    """
    flat: List[str] = []
    for _, hints in sorted(unsorted_hints.items()):
        if transform is not None:
            hints = [transform(h) for h in hints]
        if len(hints) > 1:
            hints = ["@overload\n" + h for h in hints]
        flat += hints
    return flat


def _record_named_tuple(
    namedtuples: Dict[str, str], group: PythonSignatureGroup
) -> None:
    """Register the NamedTuple return type of ``group`` (if any) in ``namedtuples``.

    Deprecated signatures are skipped. A name seen twice must map to the same
    definition.
    """
    named_tuple = returns_named_tuple_pyi(group.signature)
    if named_tuple is not None and not group.signature.deprecated:
        # deprecated namedtuples are currently not included
        tuple_name, tuple_def = named_tuple
        if tuple_name in namedtuples:
            assert namedtuples[tuple_name] == tuple_def
        else:
            namedtuples[tuple_name] = tuple_def


def gen_nn_functional(fm: FileManager) -> None:
    """Write torch/nn/functional.pyi and torch/_C/_nn.pyi from their templates."""
    INPUT = "input: Tensor"
    KERNEL_SIZE = "kernel_size: Union[_int, _size]"
    STRIDE_PADDING = ", ".join(
        [
            "stride: Optional[Union[_int, _size]] = None",
            "padding: Union[_int, _size] = 0",
        ]
    )

    # TODO the list for `torch._C._nn` is nonexhaustive
    unsorted_c_nn_function_hints: Dict[str, List[str]] = {}

    for d in (2, 3):
        unsorted_c_nn_function_hints.update(
            {
                f"avg_pool{d}d": [
                    f"def avg_pool{d}d({{}}) -> Tensor: ...".format(
                        ", ".join(
                            [
                                f"{INPUT}",
                                f"{KERNEL_SIZE}",
                                f"{STRIDE_PADDING}",
                                "ceil_mode: bool = False",
                                "count_include_pad: bool = True",
                                "divisor_override: Optional[int] = None",
                            ]
                        )
                    )
                ],
                f"fractional_max_pool{d}d": [
                    f"def fractional_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join(
                            [
                                f"{INPUT}",
                                f"{KERNEL_SIZE}",
                                "output_size: Union[_int, _size]",
                                "_random_samples: Tensor",
                            ]
                        ),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
                f"adaptive_max_pool{d}d": [
                    f"def adaptive_max_pool{d}d({{}}) -> {{}}: ...".format(
                        ", ".join([f"{INPUT}", "output_size: Union[_int, _size]"]),
                        "Tuple[Tensor, Tensor]",
                    )
                ],
            }
        )

    unsorted_c_nn_function_hints.update(
        {
            "hardtanh": [
                "def hardtanh({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "hardtanh_": [
                "def hardtanh_({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "min_val: float = ...",
                            "max_val: float = ...",
                        ]
                    )
                )
            ],
            "elu_": ["def elu_(input: Tensor, alpha: float = ...) -> Tensor: ..."],
            "leaky_relu": [
                "def leaky_relu({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "negative_slope: float = ...",
                            "*",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "leaky_relu_": [
                f"def leaky_relu_({', '.join(['input: Tensor', 'negative_slope: float = ...'])}) -> Tensor: ..."
            ],
            "log_sigmoid": ["def log_sigmoid(input: Tensor) -> Tensor: ..."],
            "gelu": ["def gelu(input: Tensor, approximate: str = ...) -> Tensor: ..."],
            "softplus": [
                "def softplus({}) -> Tensor: ...".format(
                    ", ".join(
                        ["input: Tensor", "beta: int = ...", "threshold: int = ..."]
                    )
                )
            ],
            "softshrink": [
                "def softshrink(input: Tensor, lambd: float = ...) -> Tensor: ..."
            ],
            "hardsigmoid": [
                f"def hardsigmoid({', '.join(['input: Tensor', '*', 'out: Optional[Tensor] = None'])}) -> Tensor: ..."
            ],
            "linear": [
                "def linear({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "weight: Tensor",
                            "bias: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "pad": [
                "def pad({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "pad: Sequence[int]",
                            "mode: str = ...",
                            "value: Optional[float] = None",
                        ]
                    )
                )
            ],
            "one_hot": [
                "def one_hot(tensor: Tensor, num_classes: int = ...) -> Tensor: ..."
            ],
            "scaled_dot_product_attention": [
                "def scaled_dot_product_attention({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "query: Tensor",
                            "key: Tensor",
                            "value: Tensor",
                            "attn_mask: Optional[Tensor] = None",
                            "dropout_p: float = 0.0",
                            "is_causal: bool = False",
                            "scale: Optional[float] = None",
                        ]
                    )
                )
            ],
        }
    )

    c_nn_function_hints = _sorted_hints_with_overloads(unsorted_c_nn_function_hints)

    # Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
    # through an `_add_docstr` call
    torch_imports = [
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "conv_tbc",
        "avg_pool1d",
        "adaptive_avg_pool1d",
        "relu_",
        "selu_",
        "celu_",
        "prelu",
        "rrelu_",
        "hardshrink",
        "bilinear",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "native_channel_shuffle",
        "pairwise_distance",
        "pdist",
        "cosine_similarity",
    ]
    imported_hints = [f"from .. import {_} as {_}" for _ in torch_imports]

    # Functions imported into `torch.nn.functional` from `torch._C._nn`
    c_nn_imports = [
        "avg_pool2d",
        "avg_pool3d",
        "hardtanh_",
        "elu_",
        "leaky_relu_",
        "gelu",
        "softplus",
        "softshrink",
        "linear",
        "pad",
        "one_hot",
        "scaled_dot_product_attention",
    ]
    imported_hints += [f"from .._C._nn import {_} as {_}" for _ in c_nn_imports]
    # This is from `torch._C._nn` but renamed
    imported_hints.append("from .._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid")

    # Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
    unsorted_dispatched_hints: Dict[str, List[str]] = {}
    for d in (1, 2, 3):
        unsorted_dispatched_hints.update(
            **get_max_pool_dispatch(
                f"max_pool{d}d",
                [
                    f"{INPUT}",
                    f"{KERNEL_SIZE}",
                    f"{STRIDE_PADDING}",
                    "dilation: Union[_int, _size] = 1",
                    "ceil_mode: bool = False",
                    "{return_indices}",
                ],
            ),
            **get_max_pool_dispatch(
                f"fractional_max_pool{d}d",
                [
                    f"{INPUT}",
                    f"{KERNEL_SIZE}",
                    "output_size: Optional[Union[_int, _size]] = None",
                    "output_ratio: Optional[_ratio_any_t] = None",
                    "{return_indices}",
                    "_random_samples: Optional[Tensor] = None",
                ],
            ),
            **get_max_pool_dispatch(
                f"adaptive_max_pool{d}d",
                [f"{INPUT}", "output_size: Union[_int, _size]", "{return_indices}"],
            ),
        )
    # There's no fractional_max_pool1d
    del unsorted_dispatched_hints["fractional_max_pool1d"]

    dispatched_hints = _sorted_hints_with_overloads(unsorted_dispatched_hints)

    fm.write_with_template(
        "torch/nn/functional.pyi",
        "torch/nn/functional.pyi.in",
        lambda: {
            "imported_hints": imported_hints,
            "dispatched_hints": dispatched_hints,
        },
    )
    fm.write_with_template(
        "torch/_C/_nn.pyi",
        "torch/_C/_nn.pyi.in",
        lambda: {
            "c_nn_function_hints": c_nn_function_hints,
        },
    )


def gen_pyi(
    native_yaml_path: str,
    tags_yaml_path: str,
    deprecated_yaml_path: str,
    fm: FileManager,
) -> None:
    """gen_pyi()

    This function generates a pyi file for torch.

    Writes torch/_C/__init__.pyi, torch/_C/_VariableFunctions.pyi,
    torch/_VF.pyi and torch/return_types.pyi through ``fm``, then the
    nn.functional stubs via ``gen_nn_functional``.
    """

    # Some of this logic overlaps with generate_python_signature in
    # tools/autograd/gen_python_functions.py; however, this
    # function is all about generating mypy type signatures, whereas
    # the other function generates our custom format for argument
    # checking. If you update this, consider if your change
    # also needs to update the other file.

    # Dictionary for NamedTuple definitions
    namedtuples: Dict[str, str] = {}

    # Generate type signatures for top-level functions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_function_hints: Dict[str, List[str]] = collections.defaultdict(list)

    for n, n1, n2 in [
        ("csr", "crow", "col"),
        ("csc", "ccol", "row"),
        ("bsr", "crow", "col"),
        ("bsc", "ccol", "row"),
    ]:
        unsorted_function_hints.update(
            {
                f"sparse_{n}_tensor": [
                    f"def sparse_{n}_tensor({{}}) -> Tensor: ...".format(
                        ", ".join(
                            [
                                f"{n1}_indices: Union[Tensor, List]",
                                f"{n2}_indices: Union[Tensor, List]",
                                "values: Union[Tensor, List]",
                                "size: Optional[_size] = None",
                                "*",
                                "dtype: Optional[_dtype] = None",
                                "device: Union[_device, str, None] = None",
                                "requires_grad: _bool = False",
                                "check_invariants: Optional[_bool] = None",
                            ]
                        ),
                    )
                ],
            }
        )

    unsorted_function_hints.update(
        {
            "set_flush_denormal": ["def set_flush_denormal(mode: _bool) -> _bool: ..."],
            "get_default_dtype": ["def get_default_dtype() -> _dtype: ..."],
            "asarray": [
                "def asarray({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "obj: Any",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "device: Union[_device, str, None] = None",
                            "copy: Optional[_bool] = None",
                            "requires_grad: _bool = False",
                        ]
                    )
                )
            ],
            "from_numpy": ["def from_numpy(ndarray) -> Tensor: ..."],
            "frombuffer": [
                "def frombuffer({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "buffer: Any",
                            "*",
                            "dtype: _dtype",
                            "count: int = -1",
                            "offset: int = 0",
                            "device: Union[_device, str, None] = None",
                            "requires_grad: _bool = False",
                        ]
                    )
                )
            ],
            "numel": ["def numel(self: Tensor) -> _int: ..."],
            "as_tensor": [
                "def as_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "data: Any",
                            "dtype: Optional[_dtype] = None",
                            DEVICE_PARAM,
                        ]
                    )
                )
            ],
            "get_num_threads": ["def get_num_threads() -> _int: ..."],
            "set_num_threads": ["def set_num_threads(num: _int) -> None: ..."],
            "init_num_threads": ["def init_num_threads() -> None: ..."],
            "get_num_interop_threads": ["def get_num_interop_threads() -> _int: ..."],
            "set_num_interop_threads": [
                "def set_num_interop_threads(num: _int) -> None: ..."
            ],
            # These functions are explicitly disabled by
            # SKIP_PYTHON_BINDINGS because they are hand bound.
            # Correspondingly, we must hand-write their signatures.
            "tensor": [f"def tensor(data: Any, {FACTORY_PARAMS}) -> Tensor: ..."],
            "sparse_coo_tensor": [
                "def sparse_coo_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "indices: Tensor",
                            "values: Union[Tensor, List]",
                            "size: Optional[_size] = None",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "device: Union[_device, str, None] = None",
                            "requires_grad: _bool = False",
                            "check_invariants: Optional[_bool] = None",
                        ]
                    )
                )
            ],
            "sparse_compressed_tensor": [
                "def sparse_compressed_tensor({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "compressed_indices: Union[Tensor, List]",
                            "plain_indices: Union[Tensor, List]",
                            "values: Union[Tensor, List]",
                            "size: Optional[_size] = None",
                            "*",
                            "dtype: Optional[_dtype] = None",
                            "layout: Optional[_layout] = None",
                            "device: Union[_device, str, None] = None",
                            "requires_grad: _bool = False",
                            "check_invariants: Optional[_bool] = None",
                        ]
                    )
                )
            ],
            "_sync": ["def _sync(t: Tensor) -> None: ..."],
            "_is_functional_tensor": [
                "def _is_functional_tensor(t: Tensor) -> _bool: ..."
            ],
            "_from_functional_tensor": [
                "def _from_functional_tensor(t: Tensor) -> Tensor: ..."
            ],
            "_to_functional_tensor": [
                "def _to_functional_tensor(t: Tensor) -> Tensor: ..."
            ],
            "_enable_functionalization": [
                "def _enable_functionalization(*, reapply_views: _bool = False): ..."
            ],
            "_disable_functionalization": ["def _disable_functionalization(): ..."],
            "range": [
                "def range({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "step: Number = 1",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "arange": [
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "step: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def arange({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "end: Number",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "linspace": [
                "def linspace({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "steps: Optional[_int] = None",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "logspace": [
                "def logspace({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "start: Number",
                            "end: Number",
                            "steps: Optional[_int] = None",
                            "base: _float = 10.0",
                            "*",
                            "out: Optional[Tensor] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                )
            ],
            "randint": [
                "def randint({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "low: _int",
                            "high: _int",
                            "size: _size",
                            "*",
                            "generator: Optional[Generator] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def randint({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "high: _int",
                            "size: _size",
                            "*",
                            "generator: Optional[Generator] = None",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "full": [
                "def full({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "size: _size",
                            "fill_value: Union[Number, _complex]",
                            "*",
                            "out: Optional[Tensor] = None",
                            "layout: _layout = strided",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
                "def full({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "size: _size",
                            "fill_value: Union[Number, _complex]",
                            "*",
                            "names: List[Union[str, None]]",
                            "layout: _layout = strided",
                            FACTORY_PARAMS,
                        ]
                    )
                ),
            ],
            "is_grad_enabled": ["def is_grad_enabled() -> _bool: ..."],
            "is_inference_mode_enabled": [
                "def is_inference_mode_enabled() -> _bool: ..."
            ],
            "nonzero": [
                "def nonzero(input: Tensor, *, as_tuple: Literal[False] = False, out: Optional[Tensor] = None) -> Tensor: ...",
                "def nonzero(input: Tensor, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "dsmm": ["def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "hsmm": ["def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "saddmm": [
                "def saddmm({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Tensor",
                            "mat1: Tensor",
                            "mat2: Tensor",
                            "*",
                            "beta: Number = 1",
                            "alpha: Number = 1",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
            "spmm": ["def spmm(input: Tensor, mat2: Tensor) -> Tensor: ..."],
            "div": [
                "def div({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "input: Union[Tensor, Number]",
                            "other: Union[Tensor, Number]",
                            "*",
                            "rounding_mode: Optional[str] = None",
                            "out: Optional[Tensor] = None",
                        ]
                    )
                )
            ],
        }
    )
    for binop in ["mul", "true_divide", "floor_divide"]:
        unsorted_function_hints[binop].append(
            f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], "
            "*, out: Optional[Tensor] = None) -> Tensor: ..."
        )
    for binop in ["add", "sub"]:
        unsorted_function_hints[binop].append(
            f"def {binop}(input: Union[Tensor, Number], other: Union[Tensor, Number], "
            "*, alpha: Optional[Number] = 1, out: Optional[Tensor] = None) -> Tensor: ..."
        )

    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    native_functions = list(filter(should_generate_py_binding, native_functions))

    function_signatures = load_signatures(
        native_functions, deprecated_yaml_path, method=False, pyi=True
    )
    sig_groups = get_py_torch_functions(function_signatures)
    for group in sorted(sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_function_hints[name] += generate_type_hints(group)
        _record_named_tuple(namedtuples, group)

    def replace_special_case(hint: str) -> str:
        # NB: Keep this in sync with enum in aten/src/ATen/core/Reduction.h
        hint = hint.replace("at::Reduction::Mean", "1")
        hint = hint.replace(": Tensor = None", ": Optional[Tensor] = None")
        # Match both:
        # ": Union[Tensor, Tuple[Tensor, ...], List[Tensor]] = None"
        # ": Union[Tuple[Tensor, ...], List[Tensor]] = None"
        hint = hint.replace(
            "Tuple[Tensor, ...], List[Tensor]] = None",
            "Tuple[Tensor, ...], List[Tensor], None] = None",
        )
        return hint

    function_hints = _sorted_hints_with_overloads(
        unsorted_function_hints, transform=replace_special_case
    )

    # Generate type signatures for Tensor methods
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    unsorted_tensor_method_hints: Dict[str, List[str]] = collections.defaultdict(list)
    unsorted_tensor_method_hints.update(
        {
            "size": [
                "def size(self) -> Size: ...",
                "def size(self, dim: _int) -> _int: ...",
            ],
            "stride": [
                "def stride(self) -> Tuple[_int, ...]: ...",
                # The parameter needs a name *and* a type; `(self, _int)` would
                # declare an untyped parameter called `_int`.
                "def stride(self, dim: _int) -> _int: ...",
            ],
            "new_ones": [
                f"def new_ones(self, size: _size, {FACTORY_PARAMS}) -> Tensor: ..."
            ],
            "new_tensor": [
                f"def new_tensor(self, data: Any, {FACTORY_PARAMS}) -> Tensor: ..."
            ],
            # new and __init__ have the same signatures and differ only in return type
            # Adapted from legacy_tensor_ctor and legacy_tensor_new
            "new": [
                f"def new(self, *args: Any, {DEVICE_PARAM}) -> Tensor: ...",
                "def new(self, storage: Storage) -> Tensor: ...",
                "def new(self, other: Tensor) -> Tensor: ...",
                f"def new(self, size: _size, *, {DEVICE_PARAM}) -> Tensor: ...",
            ],
            "__init__": [
                f"def __init__(self, *args: Any, {DEVICE_PARAM}) -> None: ...",
                "def __init__(self, storage: Storage) -> None: ...",
                "def __init__(self, other: Tensor) -> None: ...",
                f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
            ],
            "as_subclass": ["def as_subclass(self, cls: Type[S]) -> S: ..."],
            "_make_subclass": [
                "@staticmethod\ndef _make_subclass({}) -> S: ...".format(
                    ", ".join(
                        [
                            "cls: Type[S]",
                            "data: Tensor",
                            "require_grad: _bool = False",
                            "dispatch_strides: _bool = False",
                            "dispatch_device: _bool = False",
                            "device_for_backend_keys: Optional[_device] = None",
                        ]
                    )
                )
            ],
            "__getitem__": [f"def __getitem__(self, {INDICES}) -> Tensor: ..."],
            "__setitem__": [
                f"def __setitem__(self, {INDICES}, val: Union[Tensor, Number]) -> None: ..."
            ],
            "tolist": ["def tolist(self) -> List: ..."],
            "requires_grad_": [
                "def requires_grad_(self, mode: _bool = True) -> Tensor: ..."
            ],
            "element_size": ["def element_size(self) -> _int: ..."],
            "data_ptr": ["def data_ptr(self) -> _int: ..."],
            "dim": ["def dim(self) -> _int: ..."],
            "nonzero": [
                "def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...",
                "def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...",
            ],
            "numel": ["def numel(self) -> _int: ..."],
            "ndimension": ["def ndimension(self) -> _int: ..."],
            "nelement": ["def nelement(self) -> _int: ..."],
            "cuda": [
                "def cuda({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "self",
                            "device: Optional[Union[_device, _int, str]] = None",
                            "non_blocking: _bool = False",
                        ]
                    )
                )
            ],
            "numpy": ["def numpy(self, *, force: _bool = False) -> Any: ..."],
            "apply_": ["def apply_(self, callable: Callable) -> Tensor: ..."],
            "map_": [
                "def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ..."
            ],
            "map2_": [
                "def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ..."
            ],
            # Keyed by the method it defines, like every other entry, so that
            # sorting and @overload grouping see the right name.
            "untyped_storage": ["def untyped_storage(self) -> UntypedStorage: ..."],
            "storage_type": ["def storage_type(self) -> Storage: ..."],
            "type": [
                "def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...",
                "def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False) -> Tensor: ...",
            ],
            "get_device": ["def get_device(self) -> _int: ..."],
            "contiguous": [
                "def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ..."
            ],
            "has_names": ["def has_names(self) -> _bool: ..."],
            "is_contiguous": [
                "def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ..."
            ],
            "_is_view": ["def _is_view(self) -> _bool: ..."],
            "is_cpu": ["is_cpu: _bool"],
            "is_cuda": ["is_cuda: _bool"],
            "is_leaf": ["is_leaf: _bool"],
            "is_nested": ["is_nested: _bool"],
            "is_sparse": ["is_sparse: _bool"],
            "is_sparse_csr": ["is_sparse_csr: _bool"],
            "is_quantized": ["is_quantized: _bool"],
            "is_meta": ["is_meta: _bool"],
            "is_mps": ["is_mps: _bool"],
            "is_ort": ["is_ort: _bool"],
            "is_mkldnn": ["is_mkldnn: _bool"],
            "is_vulkan": ["is_vulkan: _bool"],
            "is_ipu": ["is_ipu: _bool"],
            "storage_offset": ["def storage_offset(self) -> _int: ..."],
            "to": [
                "def to(self, dtype: _dtype, non_blocking: _bool = False, copy: _bool = False) -> Tensor: ...",
                "def to({}) -> Tensor: ...".format(
                    ", ".join(
                        [
                            "self",
                            "device: Optional[Union[_device, str]] = None",
                            "dtype: Optional[_dtype] = None",
                            "non_blocking: _bool = False",
                            "copy: _bool = False",
                        ]
                    )
                ),
                "def to(self, other: Tensor, non_blocking: _bool = False, copy: _bool = False) -> Tensor: ...",
            ],
            "item": ["def item(self) -> Number: ..."],
            "copy_": [
                "def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ..."
            ],
            "set_": [
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
                "offset: _int, size: _size, stride: _size) -> Tensor: ...",
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
            ],
            "split": [
                "def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...",
                "def split(self, split_size: Tuple[_int, ...], dim: _int = 0) -> Sequence[Tensor]: ...",
            ],
            "div": [
                "def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
            "div_": [
                "def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
            ],
        }
    )
    # Out-of-place variants accept `out=`; the trailing-underscore in-place
    # variants do not.
    for binop in ["mul", "true_divide", "floor_divide"]:
        for inplace in [False, True]:
            method_name = f"{binop}_" if inplace else binop
            out_suffix = "" if inplace else ", *, out: Optional[Tensor] = None"
            unsorted_tensor_method_hints[method_name].append(
                f"def {method_name}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]{out_suffix})"
                " -> Tensor: ..."
            )
    for binop in ["add", "sub"]:
        for inplace in [False, True]:
            method_name = f"{binop}_" if inplace else binop
            out_suffix = "" if inplace else ", out: Optional[Tensor] = None"
            unsorted_tensor_method_hints[method_name].append(
                f"def {method_name}(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], "
                f"*, alpha: Optional[Number] = 1{out_suffix})"
                " -> Tensor: ..."
            )
    simple_conversions = [
        "byte",
        "char",
        "cpu",
        "double",
        "float",
        "half",
        "int",
        "long",
        "short",
        "bool",
        "bfloat16",
    ]
    for name in simple_conversions:
        unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")

    # pyi tensor methods don't currently include deprecated signatures for some reason
    # TODO: we should probably add them in
    tensor_method_signatures = load_signatures(
        native_functions,
        deprecated_yaml_path,
        method=True,
        skip_deprecated=True,
        pyi=True,
    )
    tensor_method_sig_groups = get_py_torch_functions(
        tensor_method_signatures, method=True
    )

    for group in sorted(tensor_method_sig_groups, key=lambda g: g.signature.name):
        name = group.signature.name
        unsorted_tensor_method_hints[name] += generate_type_hints(group)
        _record_named_tuple(namedtuples, group)

    for op in all_ops:
        name = f"__{op}__"
        unsorted_tensor_method_hints[name] += sig_for_ops(name)

    tensor_method_hints = _sorted_hints_with_overloads(unsorted_tensor_method_hints)

    # TODO: Missing type hints for nn

    # Generate namedtuple definitions
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    namedtuple_defs = [f"{defn}\n" for defn in namedtuples.values()]

    # Generate type signatures for legacy classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    legacy_storage_base_hints = ["class StorageBase(object): ..."]

    legacy_class_hints = [
        f"class {c}(Tensor): ..."
        for c in (
            "DoubleTensor",
            "FloatTensor",
            "LongTensor",
            "IntTensor",
            "ShortTensor",
            "HalfTensor",
            "CharTensor",
            "ByteTensor",
            "BoolTensor",
        )
    ]

    # Generate type signatures for dtype classes
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # TODO: don't explicitly list dtypes here; get it from canonical
    # source
    dtype_class_hints = [
        f"{n}: dtype = ..."
        for n in [
            "float32",
            "float",
            "float64",
            "double",
            "float16",
            "bfloat16",
            "half",
            "uint8",
            "int8",
            "int16",
            "short",
            "int32",
            "int",
            "int64",
            "long",
            "complex32",
            "complex64",
            "cfloat",
            "complex128",
            "cdouble",
            "quint8",
            "qint8",
            "qint32",
            "bool",
            "quint4x2",
            "quint2x4",
        ]
    ]

    # Generate __all__ directive
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # Include only the functions that contain hints, to prevent undefined
    # symbols to be included in the `__all__` directive.
    hinted_function_names = [
        name for name, hint in unsorted_function_hints.items() if hint
    ]
    all_symbols = sorted(list(namedtuples.keys()) + hinted_function_names)
    all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
    all_directive[0] = f"__all__ = {all_directive[0]}"

    # Dispatch key hints
    # ~~~~~~~~~~~~~~~~~~
    dispatch_key_hints = [f"{d.name}: DispatchKey = ..." for d in DispatchKey]

    # Tags Enum type hints
    # ~~~~~~~~~~~~~~~~~~~~

    tag_names = sorted(parse_tags_yaml(tags_yaml_path))
    tag_attributes = "\n".join(
        f"{name}: _int = {index}" for index, name in enumerate(tag_names)
    )

    # Write out the stub
    # ~~~~~~~~~~~~~~~~~~

    env = {
        "namedtuple_defs": namedtuple_defs,
        "function_hints": function_hints,
        "tensor_method_hints": tensor_method_hints,
        "legacy_class_hints": legacy_class_hints,
        "legacy_storage_base_hints": legacy_storage_base_hints,
        "dtype_class_hints": dtype_class_hints,
        "dispatch_key_hints": dispatch_key_hints,
        "all_directive": all_directive,
        "tag_attributes": tag_attributes,
    }
    # "@" + "generated" is split so that this generator file itself is not
    # tagged as a generated file.
    fm.write_with_template(
        "torch/_C/__init__.pyi",
        "torch/_C/__init__.pyi.in",
        lambda: {
            "generated_comment": "@" + "generated from torch/_C/__init__.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_C/_VariableFunctions.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/_VF.pyi",
        "torch/_C/_VariableFunctions.pyi.in",
        lambda: {
            "generated_comment": "@"
            + "generated from torch/_C/_VariableFunctions.pyi.in",
            **env,
        },
    )
    fm.write_with_template(
        "torch/return_types.pyi",
        "torch/_C/return_types.pyi.in",
        lambda: {
            # Name the template this file is actually rendered from.
            "generated_comment": "@" + "generated from torch/_C/return_types.pyi.in",
            **env,
        },
    )
    gen_nn_functional(fm)


def main() -> None:
    """Command-line entry point: parse paths and generate all stubs."""
    parser = argparse.ArgumentParser(description="Generate type stubs for PyTorch")
    parser.add_argument(
        "--native-functions-path",
        metavar="NATIVE",
        default="aten/src/ATen/native/native_functions.yaml",
        help="path to native_functions.yaml",
    )
    parser.add_argument(
        "--tags-path",
        metavar="TAGS",
        default="aten/src/ATen/native/tags.yaml",
        help="path to tags.yaml",
    )
    parser.add_argument(
        "--deprecated-functions-path",
        metavar="DEPRECATED",
        default="tools/autograd/deprecated.yaml",
        help="path to deprecated.yaml",
    )
    parser.add_argument(
        "--out", metavar="OUT", default=".", help="path to output directory"
    )
    args = parser.parse_args()
    fm = FileManager(install_dir=args.out, template_dir=".", dry_run=False)
    gen_pyi(
        args.native_functions_path, args.tags_path, args.deprecated_functions_path, fm
    )


if __name__ == "__main__":
    main()